1use {
2 crate::{
3 nonblocking::{
4 connection_rate_limiter::{ConnectionRateLimiter, TotalConnectionRateLimiter},
5 stream_throttle::{
6 ConnectionStreamCounter, StakedStreamLoadEMA, STREAM_THROTTLING_INTERVAL,
7 STREAM_THROTTLING_INTERVAL_MS,
8 },
9 },
10 quic::{configure_server, QuicServerError, QuicServerParams, StreamerStats},
11 streamer::StakedNodes,
12 },
13 async_channel::{bounded as async_bounded, Receiver as AsyncReceiver, Sender as AsyncSender},
14 bytes::Bytes,
15 crossbeam_channel::Sender,
16 futures::{stream::FuturesUnordered, Future, StreamExt as _},
17 indexmap::map::{Entry, IndexMap},
18 percentage::Percentage,
19 quinn::{Accept, Connecting, Connection, Endpoint, EndpointConfig, TokioRuntime, VarInt},
20 quinn_proto::VarIntBoundsExceeded,
21 rand::{thread_rng, Rng},
22 smallvec::SmallVec,
23 solana_keypair::Keypair,
24 solana_measure::measure::Measure,
25 solana_packet::{Meta, PACKET_DATA_SIZE},
26 solana_perf::packet::{PacketBatch, PacketBatchRecycler, PACKETS_PER_BATCH},
27 solana_pubkey::Pubkey,
28 solana_quic_definitions::{
29 QUIC_CONNECTION_HANDSHAKE_TIMEOUT, QUIC_MAX_STAKED_CONCURRENT_STREAMS,
30 QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
31 QUIC_MIN_STAKED_CONCURRENT_STREAMS, QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO,
32 QUIC_TOTAL_STAKED_CONCURRENT_STREAMS, QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO,
33 },
34 solana_signature::Signature,
35 solana_time_utils as timing,
36 solana_tls_utils::get_pubkey_from_tls_certificate,
37 solana_transaction_metrics_tracker::signature_if_should_track_packet,
38 std::{
39 array,
40 fmt,
41 iter::repeat_with,
42 net::{IpAddr, SocketAddr, UdpSocket},
43 pin::Pin,
44 sync::{
46 atomic::{AtomicBool, AtomicU64, Ordering},
47 Arc, RwLock,
48 },
49 task::Poll,
50 time::{Duration, Instant},
51 },
52 tokio::{
53 select,
62 sync::{Mutex, MutexGuard},
63 task::JoinHandle,
64 time::{sleep, timeout},
65 },
66 tokio_util::sync::CancellationToken,
67};
68
/// Default time to wait for the next chunk of a stream before abandoning that
/// stream (presumably the default for `QuicServerParams::wait_for_chunk_timeout`
/// — TODO confirm in `crate::quic`).
pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);

/// Protocol identifier for TPU QUIC connections; by its name, the TLS ALPN
/// value (not referenced in this part of the file).
pub const ALPN_TPU_PROTOCOL_ID: &[u8] = b"solana-tpu";

// Application close codes/reasons passed to `Connection::close`.

// Not used in this part of the file; presumably sent when a connection's table
// entry is dropped/evicted.
const CONNECTION_CLOSE_CODE_DROPPED_ENTRY: u32 = 1;
const CONNECTION_CLOSE_REASON_DROPPED_ENTRY: &[u8] = b"dropped";

// Sent when the target connection table admits no connections at all
// (`max_connections == 0`).
const CONNECTION_CLOSE_CODE_DISALLOWED: u32 = 2;
const CONNECTION_CLOSE_REASON_DISALLOWED: &[u8] = b"disallowed";

// Sent when the computed per-peer stream limit cannot be encoded as a `VarInt`.
const CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT: u32 = 3;
const CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT: &[u8] = b"exceed_max_stream_count";

// Not used in this part of the file; presumably sent when a peer has too many
// connections.
const CONNECTION_CLOSE_CODE_TOO_MANY: u32 = 4;
const CONNECTION_CLOSE_REASON_TOO_MANY: &[u8] = b"too_many";

// Sent when a stream's contents are rejected (e.g. oversized or empty packet).
const CONNECTION_CLOSE_CODE_INVALID_STREAM: u32 = 5;
const CONNECTION_CLOSE_REASON_INVALID_STREAM: &[u8] = b"invalid_stream";

// Backward-compatible re-exports; the constants now live in `crate::quic`.
#[deprecated(
    since = "2.2.0",
    note = "Use solana_streamer::quic::DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE"
)]
pub use crate::quic::DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE;
#[deprecated(
    since = "2.2.0",
    note = "Use solana_streamer::quic::DEFAULT_MAX_STREAMS_PER_MS"
)]
pub use crate::quic::DEFAULT_MAX_STREAMS_PER_MS;

/// Server-wide limit handed to `TotalConnectionRateLimiter` in `run_server`;
/// connection attempts beyond it are ignored.
const TOTAL_CONNECTIONS_PER_SECOND: u64 = 2500;

/// Once the per-IP rate limiter tracks more than this many entries,
/// `run_server` prunes it with `retain_recent()` to bound its memory.
const CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD: usize = 100_000;
112
113#[derive(Clone)]
119struct PacketAccumulator {
120 pub meta: Meta,
121 pub chunks: SmallVec<[Bytes; 2]>,
122 pub start_time: Instant,
123}
124
125impl PacketAccumulator {
126 fn new(meta: Meta) -> Self {
127 Self {
128 meta,
129 chunks: SmallVec::default(),
130 start_time: Instant::now(),
131 }
132 }
133}
134
/// How a remote peer is classified for admission, stream limits and
/// throttling.
#[derive(Copy, Clone, Debug)]
pub enum ConnectionPeerType {
    /// Peer with no known stake, or whose stake share is too small to matter.
    Unstaked,
    /// Peer with the given stake.
    Staked(u64),
}

impl ConnectionPeerType {
    /// Returns `true` for any [`ConnectionPeerType::Staked`], whatever the amount.
    pub(crate) fn is_staked(&self) -> bool {
        match self {
            ConnectionPeerType::Staked(_) => true,
            ConnectionPeerType::Unstaked => false,
        }
    }
}
146
/// Handles returned by [`spawn_server`] / [`spawn_server_multi`].
pub struct SpawnNonBlockingServerResult {
    /// One QUIC endpoint per socket the server was started on.
    pub endpoints: Vec<Endpoint>,
    /// Counters shared with the server's tasks.
    pub stats: Arc<StreamerStats>,
    /// Handle of the spawned `run_server` task (a tokio task despite the name).
    pub thread: JoinHandle<()>,
    /// Cap on concurrently open connections: staked + unstaked limits plus
    /// 25% headroom (integer division).
    pub max_concurrent_connections: usize,
}
153
154pub fn spawn_server(
155 name: &'static str,
156 sock: UdpSocket,
157 keypair: &Keypair,
158 packet_sender: Sender<PacketBatch>,
159 exit: Arc<AtomicBool>,
160 staked_nodes: Arc<RwLock<StakedNodes>>,
161 quic_server_params: QuicServerParams,
162) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
163 spawn_server_multi(
164 name,
165 vec![sock],
166 keypair,
167 packet_sender,
168 exit,
169 staked_nodes,
170 quic_server_params,
171 )
172}
173
174pub fn spawn_server_multi(
175 name: &'static str,
176 sockets: Vec<UdpSocket>,
177 keypair: &Keypair,
178 packet_sender: Sender<PacketBatch>,
179 exit: Arc<AtomicBool>,
180 staked_nodes: Arc<RwLock<StakedNodes>>,
181 quic_server_params: QuicServerParams,
182) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
183 info!("Start {name} quic server on {sockets:?}");
184 let QuicServerParams {
185 max_unstaked_connections,
186 max_staked_connections,
187 max_connections_per_peer,
188 max_streams_per_ms,
189 max_connections_per_ipaddr_per_min,
190 wait_for_chunk_timeout,
191 coalesce,
192 coalesce_channel_size,
193 } = quic_server_params;
194 let concurrent_connections = max_staked_connections + max_unstaked_connections;
195 let max_concurrent_connections = concurrent_connections + concurrent_connections / 4;
196 let (config, _) = configure_server(keypair)?;
197
198 let endpoints = sockets
199 .into_iter()
200 .map(|sock| {
201 Endpoint::new(
202 EndpointConfig::default(),
203 Some(config.clone()),
204 sock,
205 Arc::new(TokioRuntime),
206 )
207 .map_err(QuicServerError::EndpointFailed)
208 })
209 .collect::<Result<Vec<_>, _>>()?;
210 let stats = Arc::<StreamerStats>::default();
211 let handle = tokio::spawn(run_server(
212 name,
213 endpoints.clone(),
214 packet_sender,
215 exit,
216 max_connections_per_peer,
217 staked_nodes,
218 max_staked_connections,
219 max_unstaked_connections,
220 max_streams_per_ms,
221 max_connections_per_ipaddr_per_min,
222 stats.clone(),
223 wait_for_chunk_timeout,
224 coalesce,
225 coalesce_channel_size,
226 max_concurrent_connections,
227 ));
228 Ok(SpawnNonBlockingServerResult {
229 endpoints,
230 stats,
231 thread: handle,
232 max_concurrent_connections,
233 })
234}
235
236struct ClientConnectionTracker {
241 stats: Arc<StreamerStats>,
242}
243
244impl fmt::Debug for ClientConnectionTracker {
246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 f.debug_struct("StreamerClientConnection")
248 .field(
249 "open_connections:",
250 &self.stats.open_connections.load(Ordering::Relaxed),
251 )
252 .finish()
253 }
254}
255
256impl Drop for ClientConnectionTracker {
257 fn drop(&mut self) {
259 self.stats.open_connections.fetch_sub(1, Ordering::Relaxed);
260 }
261}
262
263impl ClientConnectionTracker {
264 fn new(stats: Arc<StreamerStats>, max_concurrent_connections: usize) -> Result<Self, ()> {
267 let open_connections = stats.open_connections.fetch_add(1, Ordering::Relaxed);
268 if open_connections >= max_concurrent_connections {
269 stats.open_connections.fetch_sub(1, Ordering::Relaxed);
270 debug!(
271 "There are too many concurrent connections opened already: open: {}, max: {}",
272 open_connections, max_concurrent_connections
273 );
274 return Err(());
275 }
276
277 Ok(Self { stats })
278 }
279}
280
281#[allow(clippy::too_many_arguments)]
282async fn run_server(
283 name: &'static str,
284 endpoints: Vec<Endpoint>,
285 packet_sender: Sender<PacketBatch>,
286 exit: Arc<AtomicBool>,
287 max_connections_per_peer: usize,
288 staked_nodes: Arc<RwLock<StakedNodes>>,
289 max_staked_connections: usize,
290 max_unstaked_connections: usize,
291 max_streams_per_ms: u64,
292 max_connections_per_ipaddr_per_min: u64,
293 stats: Arc<StreamerStats>,
294 wait_for_chunk_timeout: Duration,
295 coalesce: Duration,
296 coalesce_channel_size: usize,
297 max_concurrent_connections: usize,
298) {
299 let rate_limiter = ConnectionRateLimiter::new(max_connections_per_ipaddr_per_min);
300 let overall_connection_rate_limiter =
301 TotalConnectionRateLimiter::new(TOTAL_CONNECTIONS_PER_SECOND);
302
303 const WAIT_FOR_CONNECTION_TIMEOUT: Duration = Duration::from_secs(1);
304 debug!("spawn quic server");
305 let mut last_datapoint = Instant::now();
306 let unstaked_connection_table: Arc<Mutex<ConnectionTable>> =
307 Arc::new(Mutex::new(ConnectionTable::new()));
308 let stream_load_ema = Arc::new(StakedStreamLoadEMA::new(
309 stats.clone(),
310 max_unstaked_connections,
311 max_streams_per_ms,
312 ));
313 stats
314 .quic_endpoints_count
315 .store(endpoints.len(), Ordering::Relaxed);
316 let staked_connection_table: Arc<Mutex<ConnectionTable>> =
317 Arc::new(Mutex::new(ConnectionTable::new()));
318 let (sender, receiver) = async_bounded(coalesce_channel_size);
319 tokio::spawn(packet_batch_sender(
320 packet_sender,
321 receiver,
322 exit.clone(),
323 stats.clone(),
324 coalesce,
325 ));
326
327 let mut accepts = endpoints
328 .iter()
329 .enumerate()
330 .map(|(i, incoming)| {
331 Box::pin(EndpointAccept {
332 accept: incoming.accept(),
333 endpoint: i,
334 })
335 })
336 .collect::<FuturesUnordered<_>>();
337
338 while !exit.load(Ordering::Relaxed) {
339 let timeout_connection = select! {
340 ready = accepts.next() => {
341 if let Some((connecting, i)) = ready {
342 accepts.push(
343 Box::pin(EndpointAccept {
344 accept: endpoints[i].accept(),
345 endpoint: i,
346 }
347 ));
348 Ok(connecting)
349 } else {
350 continue
352 }
353 }
354 _ = tokio::time::sleep(WAIT_FOR_CONNECTION_TIMEOUT) => {
355 Err(())
356 }
357 };
358
359 if last_datapoint.elapsed().as_secs() >= 5 {
360 stats.report(name);
361 last_datapoint = Instant::now();
362 }
363
364 if let Ok(Some(incoming)) = timeout_connection {
365 stats
366 .total_incoming_connection_attempts
367 .fetch_add(1, Ordering::Relaxed);
368
369 let remote_address = incoming.remote_address();
370
371 if rate_limiter.len() > CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD {
373 rate_limiter.retain_recent();
374 }
375 stats
376 .connection_rate_limiter_length
377 .store(rate_limiter.len(), Ordering::Relaxed);
378 debug!("Got a connection {remote_address:?}");
379 if !rate_limiter.is_allowed(&remote_address.ip()) {
380 debug!(
381 "Reject connection from {:?} -- rate limiting exceeded",
382 remote_address
383 );
384 stats
385 .connection_rate_limited_per_ipaddr
386 .fetch_add(1, Ordering::Relaxed);
387 incoming.ignore();
388 continue;
389 }
390
391 if !overall_connection_rate_limiter.is_allowed() {
393 debug!(
394 "Reject connection from {:?} -- total rate limiting exceeded",
395 remote_address.ip()
396 );
397 stats
398 .connection_rate_limited_across_all
399 .fetch_add(1, Ordering::Relaxed);
400 incoming.ignore();
401 continue;
402 }
403
404 let Ok(client_connection_tracker) =
405 ClientConnectionTracker::new(stats.clone(), max_concurrent_connections)
406 else {
407 stats
408 .refused_connections_too_many_open_connections
409 .fetch_add(1, Ordering::Relaxed);
410 incoming.refuse();
411 continue;
412 };
413
414 stats
415 .outstanding_incoming_connection_attempts
416 .fetch_add(1, Ordering::Relaxed);
417 let connecting = incoming.accept();
418 match connecting {
419 Ok(connecting) => {
420 tokio::spawn(setup_connection(
421 connecting,
422 client_connection_tracker,
423 unstaked_connection_table.clone(),
424 staked_connection_table.clone(),
425 sender.clone(),
426 max_connections_per_peer,
427 staked_nodes.clone(),
428 max_staked_connections,
429 max_unstaked_connections,
430 max_streams_per_ms,
431 stats.clone(),
432 wait_for_chunk_timeout,
433 stream_load_ema.clone(),
434 ));
435 }
436 Err(err) => {
437 debug!("Incoming::accept(): error {:?}", err);
438 }
439 }
440 } else {
441 debug!("accept(): Timed out waiting for connection");
442 }
443 }
444}
445
446fn prune_unstaked_connection_table(
447 unstaked_connection_table: &mut ConnectionTable,
448 max_unstaked_connections: usize,
449 stats: Arc<StreamerStats>,
450) {
451 if unstaked_connection_table.total_size >= max_unstaked_connections {
452 const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
453 let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
454
455 let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
456 let num_pruned = unstaked_connection_table.prune_oldest(max_connections);
457 stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
458 }
459}
460
461pub fn get_remote_pubkey(connection: &Connection) -> Option<Pubkey> {
462 connection
464 .peer_identity()?
465 .downcast::<Vec<rustls::pki_types::CertificateDer>>()
466 .ok()
467 .filter(|certs| certs.len() == 1)?
468 .first()
469 .and_then(get_pubkey_from_tls_certificate)
470}
471
472fn get_connection_stake(
473 connection: &Connection,
474 staked_nodes: &RwLock<StakedNodes>,
475) -> Option<(Pubkey, u64, u64, u64, u64)> {
476 let pubkey = get_remote_pubkey(connection)?;
477 debug!("Peer public key is {pubkey:?}");
478 let staked_nodes = staked_nodes.read().unwrap();
479 Some((
480 pubkey,
481 staked_nodes.get_node_stake(&pubkey)?,
482 staked_nodes.total_stake(),
483 staked_nodes.max_stake(),
484 staked_nodes.min_stake(),
485 ))
486}
487
488pub fn compute_max_allowed_uni_streams(peer_type: ConnectionPeerType, total_stake: u64) -> usize {
489 match peer_type {
490 ConnectionPeerType::Staked(peer_stake) => {
491 if total_stake == 0 || peer_stake > total_stake {
493 warn!(
494 "Invalid stake values: peer_stake: {:?}, total_stake: {:?}",
495 peer_stake, total_stake,
496 );
497
498 QUIC_MIN_STAKED_CONCURRENT_STREAMS
499 } else {
500 let delta = (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS
501 - QUIC_MIN_STAKED_CONCURRENT_STREAMS) as f64;
502
503 (((peer_stake as f64 / total_stake as f64) * delta) as usize
504 + QUIC_MIN_STAKED_CONCURRENT_STREAMS)
505 .clamp(
506 QUIC_MIN_STAKED_CONCURRENT_STREAMS,
507 QUIC_MAX_STAKED_CONCURRENT_STREAMS,
508 )
509 }
510 }
511 ConnectionPeerType::Unstaked => QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
512 }
513}
514
/// Why a newly established connection could not be registered.
enum ConnectionHandlerError {
    /// The connection was not added: the connection table refused it, or the
    /// target table allows no connections at all.
    ConnectionAddError,
    /// The computed per-peer stream limit could not be encoded as a QUIC `VarInt`.
    MaxStreamError,
}
519
/// Per-connection settings shared by connection setup and `handle_connection`.
#[derive(Clone)]
struct NewConnectionHandlerParams {
    /// Channel into the packet batcher task.
    packet_sender: AsyncSender<PacketAccumulator>,
    /// Peer pubkey when its stake was found (even if the peer is then
    /// classified as `Unstaked`); `None` otherwise.
    remote_pubkey: Option<Pubkey>,
    /// Classification used for stream limits, receive window and throttling.
    peer_type: ConnectionPeerType,
    /// Total stake at setup time (0 when built via `new_unstaked`).
    total_stake: u64,
    /// Per-peer connection cap passed to `ConnectionTable::try_add_connection`.
    max_connections_per_peer: usize,
    /// Shared server counters.
    stats: Arc<StreamerStats>,
    /// Largest stake at setup time; used to scale the receive window.
    max_stake: u64,
    /// Smallest stake at setup time; used to scale the receive window.
    min_stake: u64,
}
536
537impl NewConnectionHandlerParams {
538 fn new_unstaked(
539 packet_sender: AsyncSender<PacketAccumulator>,
540 max_connections_per_peer: usize,
541 stats: Arc<StreamerStats>,
542 ) -> NewConnectionHandlerParams {
543 NewConnectionHandlerParams {
544 packet_sender,
545 remote_pubkey: None,
546 peer_type: ConnectionPeerType::Unstaked,
547 total_stake: 0,
548 max_connections_per_peer,
549 stats,
550 max_stake: 0,
551 min_stake: 0,
552 }
553 }
554}
555
556fn handle_and_cache_new_connection(
557 client_connection_tracker: ClientConnectionTracker,
558 connection: Connection,
559 mut connection_table_l: MutexGuard<ConnectionTable>,
560 connection_table: Arc<Mutex<ConnectionTable>>,
561 params: &NewConnectionHandlerParams,
562 wait_for_chunk_timeout: Duration,
563 stream_load_ema: Arc<StakedStreamLoadEMA>,
564) -> Result<(), ConnectionHandlerError> {
565 if let Ok(max_uni_streams) = VarInt::from_u64(compute_max_allowed_uni_streams(
566 params.peer_type,
567 params.total_stake,
568 ) as u64)
569 {
570 let remote_addr = connection.remote_address();
571 let receive_window =
572 compute_recieve_window(params.max_stake, params.min_stake, params.peer_type);
573
574 debug!(
575 "Peer type {:?}, total stake {}, max streams {} receive_window {:?} from peer {}",
576 params.peer_type,
577 params.total_stake,
578 max_uni_streams.into_inner(),
579 receive_window,
580 remote_addr,
581 );
582
583 if let Some((last_update, cancel_connection, stream_counter)) = connection_table_l
584 .try_add_connection(
585 ConnectionTableKey::new(remote_addr.ip(), params.remote_pubkey),
586 remote_addr.port(),
587 client_connection_tracker,
588 Some(connection.clone()),
589 params.peer_type,
590 timing::timestamp(),
591 params.max_connections_per_peer,
592 )
593 {
594 drop(connection_table_l);
595
596 if let Ok(receive_window) = receive_window {
597 connection.set_receive_window(receive_window);
598 }
599 connection.set_max_concurrent_uni_streams(max_uni_streams);
600
601 tokio::spawn(handle_connection(
602 connection,
603 remote_addr,
604 last_update,
605 connection_table,
606 cancel_connection,
607 params.clone(),
608 wait_for_chunk_timeout,
609 stream_load_ema,
610 stream_counter,
611 ));
612 Ok(())
613 } else {
614 params
615 .stats
616 .connection_add_failed
617 .fetch_add(1, Ordering::Relaxed);
618 Err(ConnectionHandlerError::ConnectionAddError)
619 }
620 } else {
621 connection.close(
622 CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT.into(),
623 CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT,
624 );
625 params
626 .stats
627 .connection_add_failed_invalid_stream_count
628 .fetch_add(1, Ordering::Relaxed);
629 Err(ConnectionHandlerError::MaxStreamError)
630 }
631}
632
633async fn prune_unstaked_connections_and_add_new_connection(
634 client_connection_tracker: ClientConnectionTracker,
635 connection: Connection,
636 connection_table: Arc<Mutex<ConnectionTable>>,
637 max_connections: usize,
638 params: &NewConnectionHandlerParams,
639 wait_for_chunk_timeout: Duration,
640 stream_load_ema: Arc<StakedStreamLoadEMA>,
641) -> Result<(), ConnectionHandlerError> {
642 let stats = params.stats.clone();
643 if max_connections > 0 {
644 let connection_table_clone = connection_table.clone();
645 let mut connection_table = connection_table.lock().await;
646 prune_unstaked_connection_table(&mut connection_table, max_connections, stats);
647 handle_and_cache_new_connection(
648 client_connection_tracker,
649 connection,
650 connection_table,
651 connection_table_clone,
652 params,
653 wait_for_chunk_timeout,
654 stream_load_ema,
655 )
656 } else {
657 connection.close(
658 CONNECTION_CLOSE_CODE_DISALLOWED.into(),
659 CONNECTION_CLOSE_REASON_DISALLOWED,
660 );
661 Err(ConnectionHandlerError::ConnectionAddError)
662 }
663}
664
665fn compute_receive_window_ratio_for_staked_node(max_stake: u64, min_stake: u64, stake: u64) -> u64 {
667 if stake > max_stake {
676 return QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
677 }
678
679 let max_ratio = QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
680 let min_ratio = QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO;
681 if max_stake > min_stake {
682 let a = (max_ratio - min_ratio) as f64 / (max_stake - min_stake) as f64;
683 let b = max_ratio as f64 - ((max_stake as f64) * a);
684 let ratio = (a * stake as f64) + b;
685 ratio.round() as u64
686 } else {
687 QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO
688 }
689}
690
691fn compute_recieve_window(
692 max_stake: u64,
693 min_stake: u64,
694 peer_type: ConnectionPeerType,
695) -> Result<VarInt, VarIntBoundsExceeded> {
696 match peer_type {
697 ConnectionPeerType::Unstaked => {
698 VarInt::from_u64(PACKET_DATA_SIZE as u64 * QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO)
699 }
700 ConnectionPeerType::Staked(peer_stake) => {
701 let ratio =
702 compute_receive_window_ratio_for_staked_node(max_stake, min_stake, peer_stake);
703 VarInt::from_u64(PACKET_DATA_SIZE as u64 * ratio)
704 }
705 }
706}
707
708#[allow(clippy::too_many_arguments)]
709async fn setup_connection(
710 connecting: Connecting,
711 client_connection_tracker: ClientConnectionTracker,
712 unstaked_connection_table: Arc<Mutex<ConnectionTable>>,
713 staked_connection_table: Arc<Mutex<ConnectionTable>>,
714 packet_sender: AsyncSender<PacketAccumulator>,
715 max_connections_per_peer: usize,
716 staked_nodes: Arc<RwLock<StakedNodes>>,
717 max_staked_connections: usize,
718 max_unstaked_connections: usize,
719 max_streams_per_ms: u64,
720 stats: Arc<StreamerStats>,
721 wait_for_chunk_timeout: Duration,
722 stream_load_ema: Arc<StakedStreamLoadEMA>,
723) {
724 const PRUNE_RANDOM_SAMPLE_SIZE: usize = 2;
725 let from = connecting.remote_address();
726 let res = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await;
727 stats
728 .outstanding_incoming_connection_attempts
729 .fetch_sub(1, Ordering::Relaxed);
730 if let Ok(connecting_result) = res {
731 match connecting_result {
732 Ok(new_connection) => {
733 stats.total_new_connections.fetch_add(1, Ordering::Relaxed);
734
735 let params = get_connection_stake(&new_connection, &staked_nodes).map_or(
736 NewConnectionHandlerParams::new_unstaked(
737 packet_sender.clone(),
738 max_connections_per_peer,
739 stats.clone(),
740 ),
741 |(pubkey, stake, total_stake, max_stake, min_stake)| {
742 let min_stake_ratio =
745 1_f64 / (max_streams_per_ms * STREAM_THROTTLING_INTERVAL_MS) as f64;
746 let stake_ratio = stake as f64 / total_stake as f64;
747 let peer_type = if stake_ratio < min_stake_ratio {
748 ConnectionPeerType::Unstaked
750 } else {
751 ConnectionPeerType::Staked(stake)
752 };
753 NewConnectionHandlerParams {
754 packet_sender,
755 remote_pubkey: Some(pubkey),
756 peer_type,
757 total_stake,
758 max_connections_per_peer,
759 stats: stats.clone(),
760 max_stake,
761 min_stake,
762 }
763 },
764 );
765
766 match params.peer_type {
767 ConnectionPeerType::Staked(stake) => {
768 let mut connection_table_l = staked_connection_table.lock().await;
769
770 if connection_table_l.total_size >= max_staked_connections {
771 let num_pruned =
772 connection_table_l.prune_random(PRUNE_RANDOM_SAMPLE_SIZE, stake);
773 stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
774 }
775
776 if connection_table_l.total_size < max_staked_connections {
777 if let Ok(()) = handle_and_cache_new_connection(
778 client_connection_tracker,
779 new_connection,
780 connection_table_l,
781 staked_connection_table.clone(),
782 ¶ms,
783 wait_for_chunk_timeout,
784 stream_load_ema.clone(),
785 ) {
786 stats
787 .connection_added_from_staked_peer
788 .fetch_add(1, Ordering::Relaxed);
789 }
790 } else {
791 if let Ok(()) = prune_unstaked_connections_and_add_new_connection(
795 client_connection_tracker,
796 new_connection,
797 unstaked_connection_table.clone(),
798 max_unstaked_connections,
799 ¶ms,
800 wait_for_chunk_timeout,
801 stream_load_ema.clone(),
802 )
803 .await
804 {
805 stats
806 .connection_added_from_staked_peer
807 .fetch_add(1, Ordering::Relaxed);
808 } else {
809 stats
810 .connection_add_failed_on_pruning
811 .fetch_add(1, Ordering::Relaxed);
812 stats
813 .connection_add_failed_staked_node
814 .fetch_add(1, Ordering::Relaxed);
815 }
816 }
817 }
818 ConnectionPeerType::Unstaked => {
819 if let Ok(()) = prune_unstaked_connections_and_add_new_connection(
820 client_connection_tracker,
821 new_connection,
822 unstaked_connection_table.clone(),
823 max_unstaked_connections,
824 ¶ms,
825 wait_for_chunk_timeout,
826 stream_load_ema.clone(),
827 )
828 .await
829 {
830 stats
831 .connection_added_from_unstaked_peer
832 .fetch_add(1, Ordering::Relaxed);
833 } else {
834 stats
835 .connection_add_failed_unstaked_node
836 .fetch_add(1, Ordering::Relaxed);
837 }
838 }
839 }
840 }
841 Err(e) => {
842 handle_connection_error(e, &stats, from);
843 }
844 }
845 } else {
846 stats
847 .connection_setup_timeout
848 .fetch_add(1, Ordering::Relaxed);
849 }
850}
851
852fn handle_connection_error(e: quinn::ConnectionError, stats: &StreamerStats, from: SocketAddr) {
853 debug!("error: {:?} from: {:?}", e, from);
854 stats.connection_setup_error.fetch_add(1, Ordering::Relaxed);
855 match e {
856 quinn::ConnectionError::TimedOut => {
857 stats
858 .connection_setup_error_timed_out
859 .fetch_add(1, Ordering::Relaxed);
860 }
861 quinn::ConnectionError::ConnectionClosed(_) => {
862 stats
863 .connection_setup_error_closed
864 .fetch_add(1, Ordering::Relaxed);
865 }
866 quinn::ConnectionError::TransportError(_) => {
867 stats
868 .connection_setup_error_transport
869 .fetch_add(1, Ordering::Relaxed);
870 }
871 quinn::ConnectionError::ApplicationClosed(_) => {
872 stats
873 .connection_setup_error_app_closed
874 .fetch_add(1, Ordering::Relaxed);
875 }
876 quinn::ConnectionError::Reset => {
877 stats
878 .connection_setup_error_reset
879 .fetch_add(1, Ordering::Relaxed);
880 }
881 quinn::ConnectionError::LocallyClosed => {
882 stats
883 .connection_setup_error_locally_closed
884 .fetch_add(1, Ordering::Relaxed);
885 }
886 _ => {}
887 }
888}
889
890async fn packet_batch_sender(
893 packet_sender: Sender<PacketBatch>,
894 packet_receiver: AsyncReceiver<PacketAccumulator>,
895 exit: Arc<AtomicBool>,
896 stats: Arc<StreamerStats>,
897 coalesce: Duration,
898) {
899 trace!("enter packet_batch_sender");
900 let recycler = PacketBatchRecycler::default();
901 let mut batch_start_time = Instant::now();
902 loop {
903 let mut packet_perf_measure: Vec<([u8; 64], Instant)> = Vec::default();
904 let mut packet_batch =
905 PacketBatch::new_with_recycler(&recycler, PACKETS_PER_BATCH, "quic_packet_coalescer");
906 let mut total_bytes: usize = 0;
907
908 stats
909 .total_packet_batches_allocated
910 .fetch_add(1, Ordering::Relaxed);
911 stats
912 .total_packets_allocated
913 .fetch_add(PACKETS_PER_BATCH, Ordering::Relaxed);
914
915 loop {
916 if exit.load(Ordering::Relaxed) {
917 return;
918 }
919 let elapsed = batch_start_time.elapsed();
920 if packet_batch.len() >= PACKETS_PER_BATCH
921 || (!packet_batch.is_empty() && elapsed >= coalesce)
922 {
923 let len = packet_batch.len();
924 track_streamer_fetch_packet_performance(&packet_perf_measure, &stats);
925
926 if let Err(e) = packet_sender.send(packet_batch) {
927 stats
928 .total_packet_batch_send_err
929 .fetch_add(1, Ordering::Relaxed);
930 trace!("Send error: {}", e);
931 } else {
932 stats
933 .total_packet_batches_sent
934 .fetch_add(1, Ordering::Relaxed);
935
936 stats
937 .total_packets_sent_to_consumer
938 .fetch_add(len, Ordering::Relaxed);
939
940 stats
941 .total_bytes_sent_to_consumer
942 .fetch_add(total_bytes, Ordering::Relaxed);
943
944 trace!("Sent {} packet batch", len);
945 }
946 break;
947 }
948
949 let timeout_res = if !packet_batch.is_empty() {
950 timeout(coalesce - elapsed, packet_receiver.recv()).await
952 } else {
953 Ok(packet_receiver.recv().await)
961 };
962
963 if let Ok(Ok(packet_accumulator)) = timeout_res {
964 if packet_batch.is_empty() {
966 batch_start_time = Instant::now();
967 }
968
969 unsafe {
970 packet_batch.set_len(packet_batch.len() + 1);
971 }
972
973 let i = packet_batch.len() - 1;
974 *packet_batch[i].meta_mut() = packet_accumulator.meta;
975 let num_chunks = packet_accumulator.chunks.len();
976 let mut offset = 0;
977 for chunk in packet_accumulator.chunks {
978 packet_batch[i].buffer_mut()[offset..offset + chunk.len()]
979 .copy_from_slice(&chunk);
980 offset += chunk.len();
981 }
982
983 total_bytes += packet_batch[i].meta().size;
984
985 if let Some(signature) = signature_if_should_track_packet(&packet_batch[i])
986 .ok()
987 .flatten()
988 {
989 packet_perf_measure.push((*signature, packet_accumulator.start_time));
990 packet_batch[i].meta_mut().set_track_performance(true);
992 }
993 stats
994 .total_chunks_processed_by_batcher
995 .fetch_add(num_chunks, Ordering::Relaxed);
996 }
997 }
998 }
999}
1000
1001fn track_streamer_fetch_packet_performance(
1002 packet_perf_measure: &[([u8; 64], Instant)],
1003 stats: &StreamerStats,
1004) {
1005 if packet_perf_measure.is_empty() {
1006 return;
1007 }
1008 let mut measure = Measure::start("track_perf");
1009 let mut process_sampled_packets_us_hist = stats.process_sampled_packets_us_hist.lock().unwrap();
1010
1011 let now = Instant::now();
1012 for (signature, start_time) in packet_perf_measure {
1013 let duration = now.duration_since(*start_time);
1014 debug!(
1015 "QUIC streamer fetch stage took {duration:?} for transaction {:?}",
1016 Signature::from(*signature)
1017 );
1018 process_sampled_packets_us_hist
1019 .increment(duration.as_micros() as u64)
1020 .unwrap();
1021 }
1022
1023 drop(process_sampled_packets_us_hist);
1024 measure.stop();
1025 stats
1026 .perf_track_overhead_us
1027 .fetch_add(measure.as_us(), Ordering::Relaxed);
1028}
1029
/// Serves one registered connection until it closes, is cancelled, or sends an
/// invalid stream.
///
/// For each accepted uni-directional stream:
/// 1. If the peer has already used its stream budget for the current
///    throttling interval, sleep until that interval ends. The budget comes
///    from `stream_load_ema`; usage is counted in `stream_counter`.
/// 2. Read the stream chunk by chunk, each read bounded by
///    `wait_for_chunk_timeout`, and feed the chunks to `handle_chunks`.
///
/// A stream read error or timeout abandons only that stream. An `Err` from
/// `handle_chunks` closes the whole connection with
/// `CONNECTION_CLOSE_CODE_INVALID_STREAM`. On exit, the connection's entry is
/// removed from `connection_table`.
async fn handle_connection(
    connection: Connection,
    remote_addr: SocketAddr,
    last_update: Arc<AtomicU64>,
    connection_table: Arc<Mutex<ConnectionTable>>,
    cancel: CancellationToken,
    params: NewConnectionHandlerParams,
    wait_for_chunk_timeout: Duration,
    stream_load_ema: Arc<StakedStreamLoadEMA>,
    stream_counter: Arc<ConnectionStreamCounter>,
) {
    let NewConnectionHandlerParams {
        packet_sender,
        peer_type,
        remote_pubkey,
        stats,
        total_stake,
        ..
    } = params;

    debug!(
        "quic new connection {} streams: {} connections: {}",
        remote_addr,
        stats.total_streams.load(Ordering::Relaxed),
        stats.total_connections.load(Ordering::Relaxed),
    );
    stats.total_connections.fetch_add(1, Ordering::Relaxed);

    // One iteration per accepted stream.
    'conn: loop {
        // Leave when the peer can no longer open streams (connection error or
        // close) or when `cancel` fires (presumably when the connection table
        // drops this entry).
        let mut stream = select! {
            stream = connection.accept_uni() => match stream {
                Ok(stream) => stream,
                Err(e) => {
                    debug!("stream error: {:?}", e);
                    break;
                }
            },
            _ = cancel.cancelled() => break,
        };

        // Stream budget for this peer in the current throttling interval.
        let max_streams_per_throttling_interval =
            stream_load_ema.available_load_capacity_in_throttling_duration(peer_type, total_stake);

        // Presumably starts a new interval (resetting `stream_count`) once the
        // previous one has elapsed, and returns the current interval's start.
        let throttle_interval_start = stream_counter.reset_throttling_params_if_needed();
        let streams_read_in_throttle_interval = stream_counter.stream_count.load(Ordering::Relaxed);
        if streams_read_in_throttle_interval >= max_streams_per_throttling_interval {
            // Budget exhausted: wait out the rest of the current interval.
            let throttle_duration =
                STREAM_THROTTLING_INTERVAL.saturating_sub(throttle_interval_start.elapsed());

            if !throttle_duration.is_zero() {
                debug!("Throttling stream from {remote_addr:?}, peer type: {:?}, total stake: {}, \
                    max_streams_per_interval: {max_streams_per_throttling_interval}, read_interval_streams: {streams_read_in_throttle_interval} \
                    throttle_duration: {throttle_duration:?}",
                    peer_type, total_stake);
                stats.throttled_streams.fetch_add(1, Ordering::Relaxed);
                match peer_type {
                    ConnectionPeerType::Unstaked => {
                        stats
                            .throttled_unstaked_streams
                            .fetch_add(1, Ordering::Relaxed);
                    }
                    ConnectionPeerType::Staked(_) => {
                        stats
                            .throttled_staked_streams
                            .fetch_add(1, Ordering::Relaxed);
                    }
                }
                sleep(throttle_duration).await;
            }
        }
        // Account for the stream we are about to read.
        stream_load_ema.increment_load(peer_type);
        stream_counter.stream_count.fetch_add(1, Ordering::Relaxed);
        stats.total_streams.fetch_add(1, Ordering::Relaxed);
        stats.total_new_streams.fetch_add(1, Ordering::Relaxed);

        let mut meta = Meta::default();
        meta.set_socket_addr(&remote_addr);
        meta.set_from_staked_node(matches!(peer_type, ConnectionPeerType::Staked(_)));
        let mut accum = PacketAccumulator::new(meta);

        // Scratch space: `read_chunks` fills up to 4 chunks per call.
        let mut chunks: [Bytes; 4] = array::from_fn(|_| Bytes::new());

        loop {
            let n_chunks = match tokio::select! {
                chunk = tokio::time::timeout(
                    wait_for_chunk_timeout,
                    stream.read_chunks(&mut chunks)) => chunk,

                _ = cancel.cancelled() => break,
            } {
                // `None` means the stream has finished; report it as 0 chunks so
                // `handle_chunks` finalizes the packet.
                Ok(Ok(chunk)) => chunk.unwrap_or(0),
                Ok(Err(e)) => {
                    debug!("Received stream error: {:?}", e);
                    stats
                        .total_stream_read_errors
                        .fetch_add(1, Ordering::Relaxed);
                    break;
                }
                Err(_) => {
                    debug!("Timeout in receiving on stream");
                    stats
                        .total_stream_read_timeouts
                        .fetch_add(1, Ordering::Relaxed);
                    break;
                }
            };

            match handle_chunks(
                chunks.iter().take(n_chunks).cloned(),
                &mut accum,
                &packet_sender,
                &stats,
                peer_type,
            )
            .await
            {
                Ok(StreamState::Finished) => {
                    // Record activity on this connection (presumably consulted
                    // by the connection table when choosing what to prune).
                    last_update.store(timing::timestamp(), Ordering::Relaxed);
                    break;
                }
                Ok(StreamState::Receiving) => {}
                Err(_) => {
                    connection.close(
                        CONNECTION_CLOSE_CODE_INVALID_STREAM.into(),
                        CONNECTION_CLOSE_REASON_INVALID_STREAM,
                    );
                    // `break 'conn` skips the per-stream cleanup after this
                    // loop, so undo the stream accounting here.
                    stats.total_streams.fetch_sub(1, Ordering::Relaxed);
                    stream_load_ema.update_ema_if_needed();
                    break 'conn;
                }
            }
        }

        // Per-stream cleanup for every way the read loop above can end
        // (finished, read error/timeout, cancellation).
        stats.total_streams.fetch_sub(1, Ordering::Relaxed);
        stream_load_ema.update_ema_if_needed();
    }

    // Remove this exact connection (matched by stable id) from the table.
    let stable_id = connection.stable_id();
    let removed_connection_count = connection_table.lock().await.remove_connection(
        ConnectionTableKey::new(remote_addr.ip(), remote_pubkey),
        remote_addr.port(),
        stable_id,
    );
    if removed_connection_count > 0 {
        stats
            .connection_removed
            .fetch_add(removed_connection_count, Ordering::Relaxed);
    } else {
        stats
            .connection_remove_failed
            .fetch_add(1, Ordering::Relaxed);
    }
    stats.total_connections.fetch_sub(1, Ordering::Relaxed);
}
1207
/// Outcome of feeding one batch of chunks to `handle_chunks`.
enum StreamState {
    /// A non-empty batch was appended; more chunks may follow, keep reading the stream.
    Receiving,
    /// An empty batch (end of stream) was read and `handle_chunks` attempted to forward the
    /// accumulated packet; the caller stops reading this stream.
    Finished,
}
1214
1215async fn handle_chunks(
1220 chunks: impl ExactSizeIterator<Item = Bytes>,
1221 accum: &mut PacketAccumulator,
1222 packet_sender: &AsyncSender<PacketAccumulator>,
1223 stats: &StreamerStats,
1224 peer_type: ConnectionPeerType,
1225) -> Result<StreamState, ()> {
1226 let n_chunks = chunks.len();
1227 for chunk in chunks {
1228 accum.meta.size += chunk.len();
1229 if accum.meta.size > PACKET_DATA_SIZE {
1230 stats.invalid_stream_size.fetch_add(1, Ordering::Relaxed);
1234 debug!("invalid stream size {}", accum.meta.size);
1235 return Err(());
1236 }
1237 accum.chunks.push(chunk);
1238 if peer_type.is_staked() {
1239 stats
1240 .total_staked_chunks_received
1241 .fetch_add(1, Ordering::Relaxed);
1242 } else {
1243 stats
1244 .total_unstaked_chunks_received
1245 .fetch_add(1, Ordering::Relaxed);
1246 }
1247 }
1248
1249 if n_chunks != 0 {
1251 return Ok(StreamState::Receiving);
1252 }
1253
1254 if accum.chunks.is_empty() {
1255 debug!("stream is empty");
1256 stats
1257 .total_packet_batches_none
1258 .fetch_add(1, Ordering::Relaxed);
1259 return Err(());
1260 }
1261
1262 let bytes_sent = accum.meta.size;
1264 let chunks_sent = accum.chunks.len();
1265
1266 if let Err(err) = packet_sender.send(accum.clone()).await {
1267 stats
1268 .total_handle_chunk_to_packet_batcher_send_err
1269 .fetch_add(1, Ordering::Relaxed);
1270 trace!("packet batch send error {:?}", err);
1271 } else {
1272 stats
1273 .total_packets_sent_for_batching
1274 .fetch_add(1, Ordering::Relaxed);
1275 stats
1276 .total_bytes_sent_for_batching
1277 .fetch_add(bytes_sent, Ordering::Relaxed);
1278 stats
1279 .total_chunks_sent_for_batching
1280 .fetch_add(chunks_sent, Ordering::Relaxed);
1281
1282 match peer_type {
1283 ConnectionPeerType::Unstaked => {
1284 stats
1285 .total_unstaked_packets_sent_for_batching
1286 .fetch_add(1, Ordering::Relaxed);
1287 }
1288 ConnectionPeerType::Staked(_) => {
1289 stats
1290 .total_staked_packets_sent_for_batching
1291 .fetch_add(1, Ordering::Relaxed);
1292 }
1293 }
1294
1295 trace!("sent {} byte packet for batching", bytes_sent);
1296 }
1297
1298 Ok(StreamState::Finished)
1299}
1300
/// One accepted connection tracked in a `ConnectionTable`.
///
/// Dropping an entry (see the `Drop` impl for this type) closes `connection`, if present, and
/// then cancels `cancel`.
#[derive(Debug)]
struct ConnectionEntry {
    // Token handed back to the caller of `try_add_connection`; cancelled when the entry drops.
    cancel: CancellationToken,
    // Staked/unstaked classification of the peer; `stake()` reads the stake out of it.
    peer_type: ConnectionPeerType,
    // Timestamp cell also handed back to the caller of `try_add_connection`; `prune_oldest`
    // evicts by the minimum of these per key.
    // NOTE(review): presumably the same cell the connection handler updates with
    // `timing::timestamp()` when a stream finishes — confirm at the call site.
    last_update: Arc<AtomicU64>,
    // Remote port; together with the connection's stable id it selects what
    // `remove_connection` removes.
    port: u16,
    // Held only for its lifetime (never read). NOTE(review): tests expect
    // `stats.open_connections` to drop back when entries go away, so its Drop presumably
    // decrements that count — confirm.
    _client_connection_tracker: ClientConnectionTracker,
    // `None` in the unit tests in this file, which populate the table without real connections.
    connection: Option<Connection>,
    // Shared by every connection under the same table key (see `try_add_connection`).
    stream_counter: Arc<ConnectionStreamCounter>,
}
1312
1313impl ConnectionEntry {
1314 fn new(
1315 cancel: CancellationToken,
1316 peer_type: ConnectionPeerType,
1317 last_update: Arc<AtomicU64>,
1318 port: u16,
1319 client_connection_tracker: ClientConnectionTracker,
1320 connection: Option<Connection>,
1321 stream_counter: Arc<ConnectionStreamCounter>,
1322 ) -> Self {
1323 Self {
1324 cancel,
1325 peer_type,
1326 last_update,
1327 port,
1328 _client_connection_tracker: client_connection_tracker,
1329 connection,
1330 stream_counter,
1331 }
1332 }
1333
1334 fn last_update(&self) -> u64 {
1335 self.last_update.load(Ordering::Relaxed)
1336 }
1337
1338 fn stake(&self) -> u64 {
1339 match self.peer_type {
1340 ConnectionPeerType::Unstaked => 0,
1341 ConnectionPeerType::Staked(stake) => stake,
1342 }
1343 }
1344}
1345
1346impl Drop for ConnectionEntry {
1347 fn drop(&mut self) {
1348 if let Some(conn) = self.connection.take() {
1349 conn.close(
1350 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1351 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1352 );
1353 }
1354 self.cancel.cancel();
1355 }
1356}
1357
/// Key under which connections are grouped in a `ConnectionTable`: the peer's pubkey when one
/// is known, otherwise its IP address (see `ConnectionTableKey::new`).
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
enum ConnectionTableKey {
    /// Peer without a known pubkey, grouped by IP address.
    IP(IpAddr),
    /// Peer identified by pubkey.
    Pubkey(Pubkey),
}
1363
1364impl ConnectionTableKey {
1365 fn new(ip: IpAddr, maybe_pubkey: Option<Pubkey>) -> Self {
1366 maybe_pubkey.map_or(ConnectionTableKey::IP(ip), |pubkey| {
1367 ConnectionTableKey::Pubkey(pubkey)
1368 })
1369 }
1370}
1371
/// Accepted connections grouped by `ConnectionTableKey`, plus a running total.
struct ConnectionTable {
    // Insertion-ordered map; the pruning methods address entries by index.
    table: IndexMap<ConnectionTableKey, Vec<ConnectionEntry>>,
    // Total number of `ConnectionEntry` values across all keys, maintained by the add, remove
    // and prune methods.
    total_size: usize,
}
1377
1378impl ConnectionTable {
1381 fn new() -> Self {
1382 Self {
1383 table: IndexMap::default(),
1384 total_size: 0,
1385 }
1386 }
1387
1388 fn prune_oldest(&mut self, max_size: usize) -> usize {
1389 let mut num_pruned = 0;
1390 let key = |(_, connections): &(_, &Vec<_>)| {
1391 connections.iter().map(ConnectionEntry::last_update).min()
1392 };
1393 while self.total_size.saturating_sub(num_pruned) > max_size {
1394 match self.table.values().enumerate().min_by_key(key) {
1395 None => break,
1396 Some((index, connections)) => {
1397 num_pruned += connections.len();
1398 self.table.swap_remove_index(index);
1399 }
1400 }
1401 }
1402 self.total_size = self.total_size.saturating_sub(num_pruned);
1403 num_pruned
1404 }
1405
1406 fn prune_random(&mut self, sample_size: usize, threshold_stake: u64) -> usize {
1411 let num_pruned = std::iter::once(self.table.len())
1412 .filter(|&size| size > 0)
1413 .flat_map(|size| {
1414 let mut rng = thread_rng();
1415 repeat_with(move || rng.gen_range(0..size))
1416 })
1417 .map(|index| {
1418 let connection = self.table[index].first();
1419 let stake = connection.map(|connection: &ConnectionEntry| connection.stake());
1420 (index, stake)
1421 })
1422 .take(sample_size)
1423 .min_by_key(|&(_, stake)| stake)
1424 .filter(|&(_, stake)| stake < Some(threshold_stake))
1425 .and_then(|(index, _)| self.table.swap_remove_index(index))
1426 .map(|(_, connections)| connections.len())
1427 .unwrap_or_default();
1428 self.total_size = self.total_size.saturating_sub(num_pruned);
1429 num_pruned
1430 }
1431
1432 fn try_add_connection(
1433 &mut self,
1434 key: ConnectionTableKey,
1435 port: u16,
1436 client_connection_tracker: ClientConnectionTracker,
1437 connection: Option<Connection>,
1438 peer_type: ConnectionPeerType,
1439 last_update: u64,
1440 max_connections_per_peer: usize,
1441 ) -> Option<(
1442 Arc<AtomicU64>,
1443 CancellationToken,
1444 Arc<ConnectionStreamCounter>,
1445 )> {
1446 let connection_entry = self.table.entry(key).or_default();
1447 let has_connection_capacity = connection_entry
1448 .len()
1449 .checked_add(1)
1450 .map(|c| c <= max_connections_per_peer)
1451 .unwrap_or(false);
1452 if has_connection_capacity {
1453 let cancel = CancellationToken::new();
1454 let last_update = Arc::new(AtomicU64::new(last_update));
1455 let stream_counter = connection_entry
1456 .first()
1457 .map(|entry| entry.stream_counter.clone())
1458 .unwrap_or(Arc::new(ConnectionStreamCounter::new()));
1459 connection_entry.push(ConnectionEntry::new(
1460 cancel.clone(),
1461 peer_type,
1462 last_update.clone(),
1463 port,
1464 client_connection_tracker,
1465 connection,
1466 stream_counter.clone(),
1467 ));
1468 self.total_size += 1;
1469 Some((last_update, cancel, stream_counter))
1470 } else {
1471 if let Some(connection) = connection {
1472 connection.close(
1473 CONNECTION_CLOSE_CODE_TOO_MANY.into(),
1474 CONNECTION_CLOSE_REASON_TOO_MANY,
1475 );
1476 }
1477 None
1478 }
1479 }
1480
1481 fn remove_connection(&mut self, key: ConnectionTableKey, port: u16, stable_id: usize) -> usize {
1483 if let Entry::Occupied(mut e) = self.table.entry(key) {
1484 let e_ref = e.get_mut();
1485 let old_size = e_ref.len();
1486
1487 e_ref.retain(|connection_entry| {
1488 connection_entry.port != port
1494 || connection_entry
1495 .connection
1496 .as_ref()
1497 .and_then(|connection| (connection.stable_id() != stable_id).then_some(0))
1498 .is_some()
1499 });
1500 let new_size = e_ref.len();
1501 if e_ref.is_empty() {
1502 e.swap_remove_entry();
1503 }
1504 let connections_removed = old_size.saturating_sub(new_size);
1505 self.total_size = self.total_size.saturating_sub(connections_removed);
1506 connections_removed
1507 } else {
1508 0
1509 }
1510 }
1511}
1512
/// Wraps an endpoint's `Accept` future so that its output is tagged with the endpoint's index.
struct EndpointAccept<'a> {
    // Index of the endpoint, returned alongside the accept result.
    endpoint: usize,
    // The pending accept future; polled through a pin projection in `poll`.
    accept: Accept<'a>,
}
1517
impl Future for EndpointAccept<'_> {
    // Result of the inner `Accept` future, tagged with the index of the endpoint it came from.
    // NOTE(review): `None` is presumably quinn's signal that the endpoint was closed — see
    // quinn `Endpoint::accept`.
    type Output = (Option<quinn::Incoming>, usize);

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
        // Copy the index out first; `self` is consumed by the projection below.
        let i = self.endpoint;
        // SAFETY: structural pin projection of the `accept` field. `self` is pinned and this
        // impl never moves `accept` out of it, so the projected `Pin<&mut Accept>` upholds the
        // pinning guarantee (see `Pin::map_unchecked_mut`).
        // NOTE(review): this stays sound only while `EndpointAccept` has no `Drop` impl or
        // method that moves `accept` out of a pinned `Self`, and no manual `Unpin` impl.
        unsafe { self.map_unchecked_mut(|this| &mut this.accept) }
            .poll(cx)
            .map(|r| (r, i))
    }
}
1531
1532#[cfg(test)]
1533pub mod test {
1534 use {
1535 super::*,
1536 crate::{
1537 nonblocking::{
1538 quic::compute_max_allowed_uni_streams,
1539 testing_utilities::{
1540 check_multiple_streams, get_client_config, make_client_endpoint,
1541 setup_quic_server, SpawnTestServerResult, TestServerConfig,
1542 },
1543 },
1544 quic::DEFAULT_TPU_COALESCE,
1545 },
1546 assert_matches::assert_matches,
1547 async_channel::unbounded as async_unbounded,
1548 crossbeam_channel::{unbounded, Receiver},
1549 quinn::{ApplicationClose, ConnectionError},
1550 solana_keypair::Keypair,
1551 solana_net_utils::bind_to_localhost,
1552 solana_signer::Signer,
1553 std::collections::HashMap,
1554 tokio::time::sleep,
1555 };
1556
    /// Sends `total` (30) single-byte streams over one connection, pausing one second after
    /// each, then waits until `total` packet batches have been received on `receiver`.
    ///
    /// NOTE(review): the receive loop has no deadline; if fewer than `total` batches ever
    /// arrive, this helper (and the calling test) hangs rather than fails.
    pub async fn check_timeout(receiver: Receiver<PacketBatch>, server_address: SocketAddr) {
        let conn1 = make_client_endpoint(&server_address, None).await;
        let total = 30;
        for i in 0..total {
            let mut s1 = conn1.open_uni().await.unwrap();
            s1.write_all(&[0u8]).await.unwrap();
            s1.finish().unwrap();
            info!("done {}", i);
            // Space the streams out over time on the same connection.
            sleep(Duration::from_millis(1000)).await;
        }
        let mut received = 0;
        loop {
            // Counts batches, not packets.
            if let Ok(_x) = receiver.try_recv() {
                received += 1;
                info!("got {}", received);
            } else {
                sleep(Duration::from_millis(500)).await;
            }
            if received >= total {
                break;
            }
        }
    }
1580
    /// Opens two client connections to the server and checks that the second one cannot be
    /// used to deliver data.
    ///
    /// Either opening a stream on it fails with `ApplicationClosed`, or writing
    /// `2 * PACKET_DATA_SIZE` bytes to that stream fails.
    pub async fn check_block_multiple_connections(server_address: SocketAddr) {
        let conn1 = make_client_endpoint(&server_address, None).await;
        let conn2 = make_client_endpoint(&server_address, None).await;
        let mut s1 = conn1.open_uni().await.unwrap();
        let s2 = conn2.open_uni().await;
        if let Ok(mut s2) = s2 {
            s1.write_all(&[0u8]).await.unwrap();
            s1.finish().unwrap();
            // More data than fits in one packet; the write is expected to fail because the
            // second connection is not allowed to deliver data.
            let data = vec![1u8; PACKET_DATA_SIZE * 2];
            s2.write_all(&data)
                .await
                .expect_err("shouldn't be able to open 2 connections");
        } else {
            assert_matches!(s2, Err(quinn::ConnectionError::ApplicationClosed(_)));
        }
    }
1603
    /// Writes exactly `PACKET_DATA_SIZE` bytes to one uni stream, one byte per `write_all`.
    ///
    /// Checks that, within about 5 seconds, the server delivers exactly one packet and that
    /// every received packet is `PACKET_DATA_SIZE` bytes long.
    pub async fn check_multiple_writes(
        receiver: Receiver<PacketBatch>,
        server_address: SocketAddr,
        client_keypair: Option<&Keypair>,
    ) {
        let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);

        // A full-size packet sent as many single-byte writes.
        let num_bytes = PACKET_DATA_SIZE;
        let num_expected_packets = 1;
        let mut s1 = conn1.open_uni().await.unwrap();
        for _ in 0..num_bytes {
            s1.write_all(&[0u8]).await.unwrap();
        }
        s1.finish().unwrap();

        let mut all_packets = vec![];
        let now = Instant::now();
        let mut total_packets = 0;
        while now.elapsed().as_secs() < 5 {
            if let Ok(packets) = receiver.try_recv() {
                total_packets += packets.len();
                all_packets.push(packets)
            } else {
                sleep(Duration::from_secs(1)).await;
            }
            if total_packets >= num_expected_packets {
                break;
            }
        }
        for batch in all_packets {
            for p in batch.iter() {
                assert_eq!(p.meta().size, num_bytes);
            }
        }
        assert_eq!(total_packets, num_expected_packets);
    }
1643
    /// Connects without a keypair and, if a uni stream can be opened at all, writes up to
    /// `PACKET_DATA_SIZE` bytes to it (ignoring write errors).
    ///
    /// Expects `stopped()` on that stream to report an error, i.e. the server did not accept
    /// the data.
    pub async fn check_unstaked_node_connect_failure(server_address: SocketAddr) {
        let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);

        if let Ok(mut s1) = conn1.open_uni().await {
            // Errors are deliberately ignored: the connection may already be closed.
            for _ in 0..PACKET_DATA_SIZE {
                s1.write_all(&[0u8]).await.unwrap_or_default();
            }
            s1.finish().unwrap_or_default();
            s1.stopped().await.unwrap_err();
        }
    }
1657
    /// Starts a test server and checks that setting `exit` lets its task finish.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_exit() {
        let SpawnTestServerResult {
            join_handle,
            exit,
            receiver: _,
            server_address: _,
            stats: _,
        } = setup_quic_server(None, TestServerConfig::default());
        exit.store(true, Ordering::Relaxed);
        join_handle.await.unwrap();
    }
1670
    /// Runs `check_timeout` against a default test server, then shuts the server down.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_timeout() {
        solana_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            exit,
            receiver,
            server_address,
            stats: _,
        } = setup_quic_server(None, TestServerConfig::default());

        check_timeout(receiver, server_address).await;
        exit.store(true, Ordering::Relaxed);
        join_handle.await.unwrap();
    }
1686
    /// Feeds 1000 single-chunk `PacketAccumulator`s directly into `packet_batch_sender`.
    ///
    /// Checks that all 1000 packets come out in batches within about 2 seconds, then stops the
    /// batcher.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_packet_batcher() {
        solana_logger::setup();
        let (pkt_batch_sender, pkt_batch_receiver) = unbounded();
        let (ptk_sender, pkt_receiver) = async_unbounded();
        let exit = Arc::new(AtomicBool::new(false));
        let stats = Arc::new(StreamerStats::default());

        let handle = tokio::spawn(packet_batch_sender(
            pkt_batch_sender,
            pkt_receiver,
            exit.clone(),
            stats,
            DEFAULT_TPU_COALESCE,
        ));

        let num_packets = 1000;

        for _i in 0..num_packets {
            let mut meta = Meta::default();
            let bytes = Bytes::from("Hello world");
            let size = bytes.len();
            meta.size = size;
            let packet_accum = PacketAccumulator {
                meta,
                chunks: smallvec::smallvec![bytes],
                start_time: Instant::now(),
            };
            ptk_sender.send(packet_accum).await.unwrap();
        }
        let mut i = 0;
        let start = Instant::now();
        while i < num_packets && start.elapsed().as_secs() < 2 {
            if let Ok(batch) = pkt_batch_receiver.try_recv() {
                i += batch.len();
            } else {
                sleep(Duration::from_millis(1)).await;
            }
        }
        assert_eq!(i, num_packets);
        exit.store(true, Ordering::Relaxed);
        // Also close the input channel so the batcher is not left waiting on it.
        // NOTE(review): presumably this is how the task observes shutdown — confirm.
        drop(ptk_sender);
        handle.await.unwrap();
    }
1732
    /// Writes one byte to a stream without finishing it and waits twice
    /// `DEFAULT_WAIT_FOR_CHUNK_TIMEOUT`.
    ///
    /// Then checks that the server recorded a stream read timeout, no longer counts the stream
    /// as open, and that further writes on the stream fail.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_stream_timeout() {
        solana_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            exit,
            receiver: _,
            server_address,
            stats,
        } = setup_quic_server(None, TestServerConfig::default());

        let conn1 = make_client_endpoint(&server_address, None).await;
        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
        assert_eq!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);

        // Send one byte and deliberately leave the stream unfinished.
        let mut s1 = conn1.open_uni().await.unwrap();
        s1.write_all(&[0u8]).await.unwrap_or_default();

        // Wait long enough for the server's per-chunk read timeout to expire.
        let sleep_time = DEFAULT_WAIT_FOR_CHUNK_TIMEOUT * 2;
        sleep(sleep_time).await;

        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
        assert_ne!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);

        // The server has given up on the stream, so writing more must fail.
        assert!(s1.write_all(&[0u8]).await.is_err());

        exit.store(true, Ordering::Relaxed);
        join_handle.await.unwrap();
    }
1767
    /// Runs `check_block_multiple_connections` against a default test server.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_block_multiple_connections() {
        solana_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            exit,
            receiver: _,
            server_address,
            stats: _,
        } = setup_quic_server(None, TestServerConfig::default());
        check_block_multiple_connections(server_address).await;
        exit.store(true, Ordering::Relaxed);
        join_handle.await.unwrap();
    }
1782
1783 #[tokio::test(flavor = "multi_thread")]
1784 async fn test_quic_server_multiple_connections_on_single_client_endpoint() {
1785 solana_logger::setup();
1786
1787 let SpawnTestServerResult {
1788 join_handle,
1789 exit,
1790 receiver: _,
1791 server_address,
1792 stats,
1793 } = setup_quic_server(
1794 None,
1795 TestServerConfig {
1796 max_connections_per_peer: 2,
1797 ..Default::default()
1798 },
1799 );
1800
1801 let client_socket = bind_to_localhost().unwrap();
1802 let mut endpoint = quinn::Endpoint::new(
1803 EndpointConfig::default(),
1804 None,
1805 client_socket,
1806 Arc::new(TokioRuntime),
1807 )
1808 .unwrap();
1809 let default_keypair = Keypair::new();
1810 endpoint.set_default_client_config(get_client_config(&default_keypair));
1811 let conn1 = endpoint
1812 .connect(server_address, "localhost")
1813 .expect("Failed in connecting")
1814 .await
1815 .expect("Failed in waiting");
1816
1817 let conn2 = endpoint
1818 .connect(server_address, "localhost")
1819 .expect("Failed in connecting")
1820 .await
1821 .expect("Failed in waiting");
1822
1823 let mut s1 = conn1.open_uni().await.unwrap();
1824 s1.write_all(&[0u8]).await.unwrap();
1825 s1.finish().unwrap();
1826
1827 let mut s2 = conn2.open_uni().await.unwrap();
1828 conn1.close(
1829 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1830 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1831 );
1832
1833 let start = Instant::now();
1834 while stats.connection_removed.load(Ordering::Relaxed) != 1 {
1835 debug!("First connection not removed yet");
1836 sleep(Duration::from_millis(10)).await;
1837 }
1838 assert!(start.elapsed().as_secs() < 1);
1839
1840 s2.write_all(&[0u8]).await.unwrap();
1841 s2.finish().unwrap();
1842
1843 conn2.close(
1844 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1845 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1846 );
1847
1848 let start = Instant::now();
1849 while stats.connection_removed.load(Ordering::Relaxed) != 2 {
1850 debug!("Second connection not removed yet");
1851 sleep(Duration::from_millis(10)).await;
1852 }
1853 assert!(start.elapsed().as_secs() < 1);
1854
1855 exit.store(true, Ordering::Relaxed);
1856 join_handle.await.unwrap();
1857 }
1858
    /// Runs `check_multiple_writes` with an unauthenticated client (no keypair).
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_multiple_writes() {
        solana_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            exit,
            receiver,
            server_address,
            stats: _,
        } = setup_quic_server(None, TestServerConfig::default());
        check_multiple_writes(receiver, server_address, None).await;
        exit.store(true, Ordering::Relaxed);
        join_handle.await.unwrap();
    }
1873
    /// A client whose pubkey has 100_000 stake sends one packet.
    ///
    /// After shutdown, checks that the connection was not counted as unstaked and that it was
    /// removed from the table exactly once, with no failed removals.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_staked_connection_removal() {
        solana_logger::setup();

        let client_keypair = Keypair::new();
        let stakes = HashMap::from([(client_keypair.pubkey(), 100_000)]);
        let staked_nodes = StakedNodes::new(
            Arc::new(stakes),
            HashMap::<Pubkey, u64>::default(),
        );
        let SpawnTestServerResult {
            join_handle,
            exit,
            receiver,
            server_address,
            stats,
        } = setup_quic_server(Some(staked_nodes), TestServerConfig::default());
        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
        exit.store(true, Ordering::Relaxed);
        join_handle.await.unwrap();
        // Give connection teardown a moment to update the stats.
        sleep(Duration::from_millis(100)).await;
        assert_eq!(
            stats
                .connection_added_from_unstaked_peer
                .load(Ordering::Relaxed),
            0
        );
        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
    }
1904
    /// Same as the staked-removal test, but the client's pubkey is listed with zero stake.
    ///
    /// Checks that the connection was not counted as staked and that it was removed exactly
    /// once, with no failed removals.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_zero_staked_connection_removal() {
        solana_logger::setup();

        let client_keypair = Keypair::new();
        let stakes = HashMap::from([(client_keypair.pubkey(), 0)]);
        let staked_nodes = StakedNodes::new(
            Arc::new(stakes),
            HashMap::<Pubkey, u64>::default(),
        );
        let SpawnTestServerResult {
            join_handle,
            exit,
            receiver,
            server_address,
            stats,
        } = setup_quic_server(Some(staked_nodes), TestServerConfig::default());
        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
        exit.store(true, Ordering::Relaxed);
        join_handle.await.unwrap();
        // Give connection teardown a moment to update the stats.
        sleep(Duration::from_millis(100)).await;
        assert_eq!(
            stats
                .connection_added_from_staked_peer
                .load(Ordering::Relaxed),
            0
        );
        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
    }
1936
    /// A client without a keypair sends one packet.
    ///
    /// After shutdown, checks that the connection was not counted as staked and that it was
    /// removed exactly once, with no failed removals.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_unstaked_connection_removal() {
        solana_logger::setup();
        let SpawnTestServerResult {
            join_handle,
            exit,
            receiver,
            server_address,
            stats,
        } = setup_quic_server(None, TestServerConfig::default());
        check_multiple_writes(receiver, server_address, None).await;
        exit.store(true, Ordering::Relaxed);
        join_handle.await.unwrap();
        // Give connection teardown a moment to update the stats.
        sleep(Duration::from_millis(100)).await;
        assert_eq!(
            stats
                .connection_added_from_staked_peer
                .load(Ordering::Relaxed),
            0
        );
        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
    }
1960
    /// Spawns a server with `max_unstaked_connections` set to 0 and checks, via
    /// `check_unstaked_node_connect_failure`, that a client without a keypair cannot deliver
    /// data.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_unstaked_node_connect_failure() {
        solana_logger::setup();
        let s = bind_to_localhost().unwrap();
        let exit = Arc::new(AtomicBool::new(false));
        let (sender, _) = unbounded();
        let keypair = Keypair::new();
        let server_address = s.local_addr().unwrap();
        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
        let SpawnNonBlockingServerResult {
            endpoints: _,
            stats: _,
            thread: t,
            max_concurrent_connections: _,
        } = spawn_server(
            "quic_streamer_test",
            s,
            &keypair,
            sender,
            exit.clone(),
            staked_nodes,
            QuicServerParams {
                max_unstaked_connections: 0,
                coalesce_channel_size: 100_000,
                ..QuicServerParams::default()
            },
        )
        .unwrap();

        check_unstaked_node_connect_failure(server_address).await;
        exit.store(true, Ordering::Relaxed);
        t.await.unwrap();
    }
1994
    /// Spawns a server allowing two connections per peer and runs `check_multiple_streams`.
    ///
    /// Checks the stream and connection counters both while the server is running and after
    /// it exits. Open connections drop to 0 after exit, while the "new" counters are cumulative.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_quic_server_multiple_streams() {
        solana_logger::setup();
        let s = bind_to_localhost().unwrap();
        let exit = Arc::new(AtomicBool::new(false));
        let (sender, receiver) = unbounded();
        let keypair = Keypair::new();
        let server_address = s.local_addr().unwrap();
        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
        let SpawnNonBlockingServerResult {
            endpoints: _,
            stats,
            thread: t,
            max_concurrent_connections: _,
        } = spawn_server(
            "quic_streamer_test",
            s,
            &keypair,
            sender,
            exit.clone(),
            staked_nodes,
            QuicServerParams {
                max_connections_per_peer: 2,
                coalesce_channel_size: 100_000,
                ..QuicServerParams::default()
            },
        )
        .unwrap();

        check_multiple_streams(receiver, server_address, None).await;
        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
        assert_eq!(stats.total_new_streams.load(Ordering::Relaxed), 20);
        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 2);
        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
        exit.store(true, Ordering::Relaxed);
        t.await.unwrap();
        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 0);
        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
    }
2034
    /// `prune_oldest` on IP-keyed entries.
    ///
    /// Five IPs get one connection each with `last_update` 0..=4, and `sockets[0]` gets a
    /// second connection with `last_update` 5, for six connections total. Pruning to 3 evicts
    /// whole keys by their oldest timestamp: `sockets[0]` (two connections) and then
    /// `sockets[1]`. The remaining keys are then removed explicitly, and all connection
    /// trackers must be released.
    #[test]
    fn test_prune_table_with_ip() {
        use std::net::Ipv4Addr;
        solana_logger::setup();
        let mut table = ConnectionTable::new();
        let mut num_entries = 5;
        let max_connections_per_peer = 10;
        let sockets: Vec<_> = (0..num_entries)
            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
            .collect();
        let stats = Arc::new(StreamerStats::default());
        for (i, socket) in sockets.iter().enumerate() {
            table
                .try_add_connection(
                    ConnectionTableKey::IP(socket.ip()),
                    socket.port(),
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    i as u64,
                    max_connections_per_peer,
                )
                .unwrap();
        }
        // A second, newer connection under the first IP.
        num_entries += 1;
        table
            .try_add_connection(
                ConnectionTableKey::IP(sockets[0].ip()),
                sockets[0].port(),
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                5,
                max_connections_per_peer,
            )
            .unwrap();

        let new_size = 3;
        let pruned = table.prune_oldest(new_size);
        assert_eq!(pruned, num_entries as usize - new_size);
        for v in table.table.values() {
            for x in v {
                assert!((x.last_update() + 1) >= (num_entries as u64 - new_size as u64));
            }
        }
        assert_eq!(table.table.len(), new_size);
        assert_eq!(table.total_size, new_size);
        // Only the surviving keys (sockets[2..]) still need removing.
        for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
        }
        assert_eq!(table.total_size, 0);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }
2088
    /// `prune_oldest` on 15 distinct pubkeys with one connection each (`last_update` 0..=14).
    ///
    /// Pruning to 3 evicts the 12 oldest keys. The survivors are then removed explicitly, and
    /// all connection trackers must be released.
    #[test]
    fn test_prune_table_with_unique_pubkeys() {
        solana_logger::setup();
        let mut table = ConnectionTable::new();

        let num_entries = 15;
        let max_connections_per_peer = 10;
        let stats = Arc::new(StreamerStats::default());

        let pubkeys: Vec<_> = (0..num_entries).map(|_| Pubkey::new_unique()).collect();
        for (i, pubkey) in pubkeys.iter().enumerate() {
            table
                .try_add_connection(
                    ConnectionTableKey::Pubkey(*pubkey),
                    0,
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    i as u64,
                    max_connections_per_peer,
                )
                .unwrap();
        }

        let new_size = 3;
        let pruned = table.prune_oldest(new_size);
        assert_eq!(pruned, num_entries as usize - new_size);
        assert_eq!(table.table.len(), new_size);
        assert_eq!(table.total_size, new_size);
        // Removing already-pruned keys is a no-op; the last three are actually removed here.
        for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) {
            table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0);
        }
        assert_eq!(table.total_size, 0);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }
2126
    /// Per-peer limit and pruning with many connections under one pubkey.
    ///
    /// One pubkey is filled to `max_connections_per_peer` (10), and an 11th connection for it
    /// must be rejected. A second pubkey is accepted, for 11 connections total. Pruning to 3
    /// evicts the older key with all its connections. Finally the remaining key is removed and
    /// all trackers must be released.
    #[test]
    fn test_prune_table_with_non_unique_pubkeys() {
        solana_logger::setup();
        let mut table = ConnectionTable::new();

        let max_connections_per_peer = 10;
        let pubkey = Pubkey::new_unique();
        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());

        (0..max_connections_per_peer).for_each(|i| {
            table
                .try_add_connection(
                    ConnectionTableKey::Pubkey(pubkey),
                    0,
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    i as u64,
                    max_connections_per_peer,
                )
                .unwrap();
        });

        // The per-peer limit is reached, so this one is rejected.
        assert!(table
            .try_add_connection(
                ConnectionTableKey::Pubkey(pubkey),
                0,
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                10,
                max_connections_per_peer,
            )
            .is_none());

        // A different pubkey has its own budget.
        let num_entries = max_connections_per_peer + 1;
        let pubkey2 = Pubkey::new_unique();
        assert!(table
            .try_add_connection(
                ConnectionTableKey::Pubkey(pubkey2),
                0,
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                10,
                max_connections_per_peer,
            )
            .is_some());

        assert_eq!(table.total_size, num_entries);

        // Whole keys are evicted, so more than `num_entries - new_max_size` may go.
        let new_max_size = 3;
        let pruned = table.prune_oldest(new_max_size);
        assert!(pruned >= num_entries - new_max_size);
        assert!(table.table.len() <= new_max_size);
        assert!(table.total_size <= new_max_size);

        table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0);
        assert_eq!(table.total_size, 0);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }
2191
    /// `prune_random` on five IP keys with stakes 1..=5.
    ///
    /// With threshold 0 no stake qualifies, so nothing is evicted. With threshold 6 every
    /// stake qualifies, so exactly one key (one connection) is evicted and four trackers remain
    /// open.
    #[test]
    fn test_prune_table_random() {
        use std::net::Ipv4Addr;
        solana_logger::setup();
        let mut table = ConnectionTable::new();
        let num_entries = 5;
        let max_connections_per_peer = 10;
        let sockets: Vec<_> = (0..num_entries)
            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
            .collect();
        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());

        for (i, socket) in sockets.iter().enumerate() {
            table
                .try_add_connection(
                    ConnectionTableKey::IP(socket.ip()),
                    socket.port(),
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Staked((i + 1) as u64),
                    i as u64,
                    max_connections_per_peer,
                )
                .unwrap();
        }

        // Threshold 0: no stake is below it.
        let pruned = table.prune_random(2, 0);
        assert_eq!(pruned, 0);

        // Threshold above every stake: exactly one sampled key is evicted.
        let pruned = table.prune_random(
            2,
            num_entries as u64 + 1,
        );
        assert_eq!(pruned, 1);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 4);
    }
2233
    /// `remove_connection` on a mix of keys.
    ///
    /// Five IPs hold two connections each, one IP holds a single connection, and one IP was
    /// never added. Entries carry no real connection, so matching on port alone removes them.
    /// Afterwards the table must be empty and all trackers released.
    #[test]
    fn test_remove_connections() {
        use std::net::Ipv4Addr;
        solana_logger::setup();
        let mut table = ConnectionTable::new();
        let num_ips = 5;
        let max_connections_per_peer = 10;
        let mut sockets: Vec<_> = (0..num_ips)
            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
            .collect();
        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());

        // Two connections per IP.
        for (i, socket) in sockets.iter().enumerate() {
            table
                .try_add_connection(
                    ConnectionTableKey::IP(socket.ip()),
                    socket.port(),
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    (i * 2) as u64,
                    max_connections_per_peer,
                )
                .unwrap();

            table
                .try_add_connection(
                    ConnectionTableKey::IP(socket.ip()),
                    socket.port(),
                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                    None,
                    ConnectionPeerType::Unstaked,
                    (i * 2 + 1) as u64,
                    max_connections_per_peer,
                )
                .unwrap();
        }

        let single_connection_addr =
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips, 0, 0, 0)), 0);
        table
            .try_add_connection(
                ConnectionTableKey::IP(single_connection_addr.ip()),
                single_connection_addr.port(),
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                (num_ips * 2) as u64,
                max_connections_per_peer,
            )
            .unwrap();

        // Never added; removing it must be a harmless no-op.
        let zero_connection_addr =
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips + 1, 0, 0, 0)), 0);

        sockets.push(single_connection_addr);
        sockets.push(zero_connection_addr);

        for socket in sockets.iter() {
            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
        }
        assert_eq!(table.total_size, 0);
        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
    }
2298
#[test]
fn test_max_allowed_uni_streams() {
    // Width of the stream range that is distributed among staked peers.
    let staked_stream_range =
        (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS - QUIC_MIN_STAKED_CONCURRENT_STREAMS) as f64;

    // Zero total stake: unstaked peers get the unstaked limit, and a staked peer is
    // expected to get the staked minimum.
    assert_eq!(
        compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0),
        QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
    );
    assert_eq!(
        compute_max_allowed_uni_streams(ConnectionPeerType::Staked(10), 0),
        QUIC_MIN_STAKED_CONCURRENT_STREAMS
    );

    // 1000 out of 10000 stake is expected to yield the staked maximum.
    assert_eq!(
        compute_max_allowed_uni_streams(ConnectionPeerType::Staked(1000), 10000),
        QUIC_MAX_STAKED_CONCURRENT_STREAMS,
    );

    // 100 out of 10000 stake (1%) is expected to yield 1% of the staked range on top
    // of the staked minimum, bounded by the staked maximum.
    let one_percent_share = (staked_stream_range / 100_f64) as usize;
    assert_eq!(
        compute_max_allowed_uni_streams(ConnectionPeerType::Staked(100), 10000),
        (one_percent_share + QUIC_MIN_STAKED_CONCURRENT_STREAMS)
            .min(QUIC_MAX_STAKED_CONCURRENT_STREAMS)
    );

    // Unstaked peers get the same fixed limit regardless of total stake.
    assert_eq!(
        compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 10000),
        QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
    );
}
2326
#[test]
fn test_cacluate_receive_window_ratio_for_staked_node() {
    let min_ratio = QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO;
    let max_ratio = QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;

    // Stake range [0, 10000]: the lower end maps to the min ratio, the upper end to
    // the max ratio, and the midpoint to the (integer) average of the two.
    let (min_stake, max_stake) = (0, 10000);
    assert_eq!(
        compute_receive_window_ratio_for_staked_node(max_stake, min_stake, min_stake),
        min_ratio
    );
    assert_eq!(
        compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake),
        max_ratio
    );
    assert_eq!(
        compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake / 2),
        (max_ratio + min_ratio) / 2
    );

    // Degenerate ranges where min_stake == max_stake yield the max ratio.
    assert_eq!(
        compute_receive_window_ratio_for_staked_node(10000, 10000, 10000),
        max_ratio
    );
    assert_eq!(
        compute_receive_window_ratio_for_staked_node(0, 0, 0),
        max_ratio
    );

    // A stake above max_stake yields the max ratio.
    assert_eq!(
        compute_receive_window_ratio_for_staked_node(1000, 10, 1010),
        max_ratio
    );
}
2360
2361 #[tokio::test(flavor = "multi_thread")]
2362 async fn test_throttling_check_no_packet_drop() {
2363 solana_logger::setup_with_default_filter();
2364
2365 let SpawnTestServerResult {
2366 join_handle,
2367 exit,
2368 receiver,
2369 server_address,
2370 stats,
2371 } = setup_quic_server(None, TestServerConfig::default());
2372
2373 let client_connection = make_client_endpoint(&server_address, None).await;
2374
2375 let expected_num_txs = 100;
2377 let start_time = tokio::time::Instant::now();
2378 for i in 0..expected_num_txs {
2379 let mut send_stream = client_connection.open_uni().await.unwrap();
2380 let data = format!("{i}").into_bytes();
2381 send_stream.write_all(&data).await.unwrap();
2382 send_stream.finish().unwrap();
2383 }
2384 let elapsed_sending: f64 = start_time.elapsed().as_secs_f64();
2385 info!("Elapsed sending: {elapsed_sending}");
2386
2387 let start_time = tokio::time::Instant::now();
2389 let mut num_txs_received = 0;
2390 while num_txs_received < expected_num_txs && start_time.elapsed() < Duration::from_secs(2) {
2391 if let Ok(packets) = receiver.try_recv() {
2392 num_txs_received += packets.len();
2393 } else {
2394 sleep(Duration::from_millis(100)).await;
2395 }
2396 }
2397 assert_eq!(expected_num_txs, num_txs_received);
2398
2399 exit.store(true, Ordering::Relaxed);
2401 join_handle.await.unwrap();
2402
2403 assert_eq!(
2404 stats.total_new_streams.load(Ordering::Relaxed),
2405 expected_num_txs
2406 );
2407 assert!(stats.throttled_unstaked_streams.load(Ordering::Relaxed) > 0);
2408 }
2409
#[test]
fn test_client_connection_tracker() {
    let stats = Arc::new(StreamerStats::default());
    let max_connections = 1;

    // The first tracker fits within the limit.
    let first = ClientConnectionTracker::new(stats.clone(), max_connections);
    assert!(first.is_ok());

    // A second tracker exceeds the limit and is rejected; the rejected attempt
    // leaves the open-connection count at 1.
    assert!(ClientConnectionTracker::new(stats.clone(), max_connections).is_err());
    assert_eq!(stats.open_connections.load(Ordering::Relaxed), 1);

    // Dropping the live tracker brings the open-connection count back to 0.
    drop(first);
    assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
}
2421
2422 #[tokio::test(flavor = "multi_thread")]
2423 async fn test_client_connection_close_invalid_stream() {
2424 let SpawnTestServerResult {
2425 join_handle,
2426 server_address,
2427 stats,
2428 exit,
2429 ..
2430 } = setup_quic_server(None, TestServerConfig::default());
2431
2432 let client_connection = make_client_endpoint(&server_address, None).await;
2433
2434 let mut send_stream = client_connection.open_uni().await.unwrap();
2435 send_stream
2436 .write_all(&[42; PACKET_DATA_SIZE + 1])
2437 .await
2438 .unwrap();
2439 match client_connection.closed().await {
2440 ConnectionError::ApplicationClosed(ApplicationClose { error_code, reason }) => {
2441 assert_eq!(error_code, CONNECTION_CLOSE_CODE_INVALID_STREAM.into());
2442 assert_eq!(reason, CONNECTION_CLOSE_REASON_INVALID_STREAM);
2443 }
2444 _ => panic!("unexpected close"),
2445 }
2446 assert_eq!(stats.invalid_stream_size.load(Ordering::Relaxed), 1);
2447 exit.store(true, Ordering::Relaxed);
2448 join_handle.await.unwrap();
2449 }
2450}