1use {
2 crate::{
3 nonblocking::{
4 connection_rate_limiter::{ConnectionRateLimiter, TotalConnectionRateLimiter},
5 stream_throttle::{
6 ConnectionStreamCounter, StakedStreamLoadEMA, STREAM_THROTTLING_INTERVAL,
7 STREAM_THROTTLING_INTERVAL_MS,
8 },
9 },
10 quic::{configure_server, QuicServerError, StreamerStats},
11 streamer::StakedNodes,
12 tls_certificates::get_pubkey_from_tls_certificate,
13 },
14 async_channel::{
15 unbounded as async_unbounded, Receiver as AsyncReceiver, Sender as AsyncSender,
16 },
17 bytes::Bytes,
18 crossbeam_channel::Sender,
19 futures::{stream::FuturesUnordered, Future, StreamExt as _},
20 indexmap::map::{Entry, IndexMap},
21 percentage::Percentage,
22 quinn::{Accept, Connecting, Connection, Endpoint, EndpointConfig, TokioRuntime, VarInt},
23 quinn_proto::VarIntBoundsExceeded,
24 rand::{thread_rng, Rng},
25 smallvec::SmallVec,
26 solana_measure::measure::Measure,
27 solana_perf::packet::{PacketBatch, PACKETS_PER_BATCH},
28 solana_sdk::{
29 packet::{Meta, PACKET_DATA_SIZE},
30 pubkey::Pubkey,
31 quic::{
32 QUIC_CONNECTION_HANDSHAKE_TIMEOUT, QUIC_MAX_STAKED_CONCURRENT_STREAMS,
33 QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
34 QUIC_MIN_STAKED_CONCURRENT_STREAMS, QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO,
35 QUIC_TOTAL_STAKED_CONCURRENT_STREAMS, QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO,
36 },
37 signature::{Keypair, Signature},
38 timing,
39 },
40 solana_transaction_metrics_tracker::signature_if_should_track_packet,
41 std::{
42 array,
43 fmt,
44 iter::repeat_with,
45 net::{IpAddr, SocketAddr, UdpSocket},
46 pin::Pin,
47 sync::{
49 atomic::{AtomicBool, AtomicU64, Ordering},
50 Arc, RwLock,
51 },
52 task::Poll,
53 time::{Duration, Instant},
54 },
55 tokio::{
56 select,
65 sync::{Mutex, MutexGuard},
66 task::JoinHandle,
67 time::{sleep, timeout},
68 },
69 tokio_util::sync::CancellationToken,
70};
71
/// Default for the `wait_for_chunk_timeout` server parameter: how long a stream
/// read may wait for the next chunk before the stream is abandoned.
pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_millis(2_000);

/// ALPN protocol identifier for TPU traffic.
pub const ALPN_TPU_PROTOCOL_ID: &[u8] = b"solana-tpu";

// Application close codes / reasons used with `Connection::close`.
const CONNECTION_CLOSE_CODE_DROPPED_ENTRY: u32 = 1;
const CONNECTION_CLOSE_REASON_DROPPED_ENTRY: &[u8] = b"dropped";

const CONNECTION_CLOSE_CODE_DISALLOWED: u32 = 2;
const CONNECTION_CLOSE_REASON_DISALLOWED: &[u8] = b"disallowed";

const CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT: u32 = 3;
const CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT: &[u8] = b"exceed_max_stream_count";

const CONNECTION_CLOSE_CODE_TOO_MANY: u32 = 4;
const CONNECTION_CLOSE_REASON_TOO_MANY: &[u8] = b"too_many";

const CONNECTION_CLOSE_CODE_INVALID_STREAM: u32 = 5;
const CONNECTION_CLOSE_REASON_INVALID_STREAM: &[u8] = b"invalid_stream";

/// Default for the `max_streams_per_ms` server parameter.
pub const DEFAULT_MAX_STREAMS_PER_MS: u64 = 250;

/// Default for the `max_connections_per_ipaddr_per_min` server parameter.
pub const DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE: u64 = 8;

/// Limit handed to the server-wide `TotalConnectionRateLimiter`.
const TOTAL_CONNECTIONS_PER_SECOND: u64 = 2_500;

/// When the per-IP rate limiter tracks more entries than this,
/// `retain_recent()` is invoked on it before checking a new connection.
const CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD: usize = 100_000;
108
109#[derive(Clone)]
115struct PacketAccumulator {
116 pub meta: Meta,
117 pub chunks: SmallVec<[Bytes; 2]>,
118 pub start_time: Instant,
119}
120
121impl PacketAccumulator {
122 fn new(meta: Meta) -> Self {
123 Self {
124 meta,
125 chunks: SmallVec::default(),
126 start_time: Instant::now(),
127 }
128 }
129}
130
/// How a connecting peer is classified for resource allocation.
///
/// `Staked` carries the peer's stake; `Unstaked` covers peers with no usable
/// stake.
#[derive(Copy, Clone, Debug)]
pub enum ConnectionPeerType {
    Unstaked,
    Staked(u64),
}

impl ConnectionPeerType {
    /// Returns `true` for the `Staked` variant, regardless of the stake amount.
    pub(crate) fn is_staked(&self) -> bool {
        match self {
            ConnectionPeerType::Staked(_) => true,
            ConnectionPeerType::Unstaked => false,
        }
    }
}
142
/// Handles returned by `spawn_server` / `spawn_server_multi`.
pub struct SpawnNonBlockingServerResult {
    /// One QUIC endpoint per socket handed to the spawn function.
    pub endpoints: Vec<Endpoint>,
    /// Counters shared with the spawned server task.
    pub stats: Arc<StreamerStats>,
    /// Join handle of the spawned `run_server` task (a Tokio task, despite the name).
    pub thread: JoinHandle<()>,
    /// Cap on concurrently open connections, computed by `spawn_server_multi`
    /// as `n + n / 4` where `n = max_staked_connections + max_unstaked_connections`.
    pub max_concurrent_connections: usize,
}
149
150#[allow(clippy::too_many_arguments)]
151pub fn spawn_server(
152 name: &'static str,
153 sock: UdpSocket,
154 keypair: &Keypair,
155 packet_sender: Sender<PacketBatch>,
156 exit: Arc<AtomicBool>,
157 max_connections_per_peer: usize,
158 staked_nodes: Arc<RwLock<StakedNodes>>,
159 max_staked_connections: usize,
160 max_unstaked_connections: usize,
161 max_streams_per_ms: u64,
162 max_connections_per_ipaddr_per_min: u64,
163 wait_for_chunk_timeout: Duration,
164 coalesce: Duration,
165) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
166 spawn_server_multi(
167 name,
168 vec![sock],
169 keypair,
170 packet_sender,
171 exit,
172 max_connections_per_peer,
173 staked_nodes,
174 max_staked_connections,
175 max_unstaked_connections,
176 max_streams_per_ms,
177 max_connections_per_ipaddr_per_min,
178 wait_for_chunk_timeout,
179 coalesce,
180 )
181}
182
183#[allow(clippy::too_many_arguments, clippy::type_complexity)]
184pub fn spawn_server_multi(
185 name: &'static str,
186 sockets: Vec<UdpSocket>,
187 keypair: &Keypair,
188 packet_sender: Sender<PacketBatch>,
189 exit: Arc<AtomicBool>,
190 max_connections_per_peer: usize,
191 staked_nodes: Arc<RwLock<StakedNodes>>,
192 max_staked_connections: usize,
193 max_unstaked_connections: usize,
194 max_streams_per_ms: u64,
195 max_connections_per_ipaddr_per_min: u64,
196 wait_for_chunk_timeout: Duration,
197 coalesce: Duration,
198) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
199 info!("Start {name} quic server on {sockets:?}");
200 let concurrent_connections = max_staked_connections + max_unstaked_connections;
201 let max_concurrent_connections = concurrent_connections + concurrent_connections / 4;
202 let (config, _) = configure_server(keypair)?;
203
204 let endpoints = sockets
205 .into_iter()
206 .map(|sock| {
207 Endpoint::new(
208 EndpointConfig::default(),
209 Some(config.clone()),
210 sock,
211 Arc::new(TokioRuntime),
212 )
213 .map_err(QuicServerError::EndpointFailed)
214 })
215 .collect::<Result<Vec<_>, _>>()?;
216 let stats = Arc::<StreamerStats>::default();
217 let handle = tokio::spawn(run_server(
218 name,
219 endpoints.clone(),
220 packet_sender,
221 exit,
222 max_connections_per_peer,
223 staked_nodes,
224 max_staked_connections,
225 max_unstaked_connections,
226 max_streams_per_ms,
227 max_connections_per_ipaddr_per_min,
228 stats.clone(),
229 wait_for_chunk_timeout,
230 coalesce,
231 max_concurrent_connections,
232 ));
233 Ok(SpawnNonBlockingServerResult {
234 endpoints,
235 stats,
236 thread: handle,
237 max_concurrent_connections,
238 })
239}
240
241struct ClientConnectionTracker {
247 stats: Arc<StreamerStats>,
248}
249
250impl fmt::Debug for ClientConnectionTracker {
252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253 f.debug_struct("StreamerClientConnection")
254 .field(
255 "open_connections:",
256 &self.stats.open_connections.load(Ordering::Relaxed),
257 )
258 .finish()
259 }
260}
261
262impl Drop for ClientConnectionTracker {
263 fn drop(&mut self) {
265 self.stats.open_connections.fetch_sub(1, Ordering::Relaxed);
266 }
267}
268
269impl ClientConnectionTracker {
270 fn new(stats: Arc<StreamerStats>, max_concurrent_connections: usize) -> Result<Self, ()> {
273 let open_connections = stats.open_connections.fetch_add(1, Ordering::Relaxed);
274 if open_connections >= max_concurrent_connections {
275 stats.open_connections.fetch_sub(1, Ordering::Relaxed);
276 debug!(
277 "There are too many concurrent connections opened already: open: {}, max: {}",
278 open_connections, max_concurrent_connections
279 );
280 return Err(());
281 }
282
283 Ok(Self { stats })
284 }
285}
286
287#[allow(clippy::too_many_arguments)]
288async fn run_server(
289 name: &'static str,
290 endpoints: Vec<Endpoint>,
291 packet_sender: Sender<PacketBatch>,
292 exit: Arc<AtomicBool>,
293 max_connections_per_peer: usize,
294 staked_nodes: Arc<RwLock<StakedNodes>>,
295 max_staked_connections: usize,
296 max_unstaked_connections: usize,
297 max_streams_per_ms: u64,
298 max_connections_per_ipaddr_per_min: u64,
299 stats: Arc<StreamerStats>,
300 wait_for_chunk_timeout: Duration,
301 coalesce: Duration,
302 max_concurrent_connections: usize,
303) {
304 let rate_limiter = ConnectionRateLimiter::new(max_connections_per_ipaddr_per_min);
305 let overall_connection_rate_limiter =
306 TotalConnectionRateLimiter::new(TOTAL_CONNECTIONS_PER_SECOND);
307
308 const WAIT_FOR_CONNECTION_TIMEOUT: Duration = Duration::from_secs(1);
309 debug!("spawn quic server");
310 let mut last_datapoint = Instant::now();
311 let unstaked_connection_table: Arc<Mutex<ConnectionTable>> =
312 Arc::new(Mutex::new(ConnectionTable::new()));
313 let stream_load_ema = Arc::new(StakedStreamLoadEMA::new(
314 stats.clone(),
315 max_unstaked_connections,
316 max_streams_per_ms,
317 ));
318 stats
319 .quic_endpoints_count
320 .store(endpoints.len(), Ordering::Relaxed);
321 let staked_connection_table: Arc<Mutex<ConnectionTable>> =
322 Arc::new(Mutex::new(ConnectionTable::new()));
323 let (sender, receiver) = async_unbounded();
324 tokio::spawn(packet_batch_sender(
325 packet_sender,
326 receiver,
327 exit.clone(),
328 stats.clone(),
329 coalesce,
330 ));
331
332 let mut accepts = endpoints
333 .iter()
334 .enumerate()
335 .map(|(i, incoming)| {
336 Box::pin(EndpointAccept {
337 accept: incoming.accept(),
338 endpoint: i,
339 })
340 })
341 .collect::<FuturesUnordered<_>>();
342
343 while !exit.load(Ordering::Relaxed) {
344 let timeout_connection = select! {
345 ready = accepts.next() => {
346 if let Some((connecting, i)) = ready {
347 accepts.push(
348 Box::pin(EndpointAccept {
349 accept: endpoints[i].accept(),
350 endpoint: i,
351 }
352 ));
353 Ok(connecting)
354 } else {
355 continue
357 }
358 }
359 _ = tokio::time::sleep(WAIT_FOR_CONNECTION_TIMEOUT) => {
360 Err(())
361 }
362 };
363
364 if last_datapoint.elapsed().as_secs() >= 5 {
365 stats.report(name);
366 last_datapoint = Instant::now();
367 }
368
369 if let Ok(Some(incoming)) = timeout_connection {
370 stats
371 .total_incoming_connection_attempts
372 .fetch_add(1, Ordering::Relaxed);
373
374 let remote_address = incoming.remote_address();
375
376 if !overall_connection_rate_limiter.is_allowed() {
378 debug!(
379 "Reject connection from {:?} -- total rate limiting exceeded",
380 remote_address.ip()
381 );
382 stats
383 .connection_rate_limited_across_all
384 .fetch_add(1, Ordering::Relaxed);
385 incoming.ignore();
386 continue;
387 }
388
389 if rate_limiter.len() > CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD {
390 rate_limiter.retain_recent();
391 }
392 stats
393 .connection_rate_limiter_length
394 .store(rate_limiter.len(), Ordering::Relaxed);
395 debug!("Got a connection {remote_address:?}");
396 if !rate_limiter.is_allowed(&remote_address.ip()) {
397 debug!(
398 "Reject connection from {:?} -- rate limiting exceeded",
399 remote_address
400 );
401 stats
402 .connection_rate_limited_per_ipaddr
403 .fetch_add(1, Ordering::Relaxed);
404 incoming.ignore();
405 continue;
406 }
407
408 let Ok(client_connection_tracker) =
409 ClientConnectionTracker::new(stats.clone(), max_concurrent_connections)
410 else {
411 stats
412 .refused_connections_too_many_open_connections
413 .fetch_add(1, Ordering::Relaxed);
414 incoming.refuse();
415 continue;
416 };
417
418 stats
419 .outstanding_incoming_connection_attempts
420 .fetch_add(1, Ordering::Relaxed);
421 let connecting = incoming.accept();
422 match connecting {
423 Ok(connecting) => {
424 tokio::spawn(setup_connection(
425 connecting,
426 client_connection_tracker,
427 unstaked_connection_table.clone(),
428 staked_connection_table.clone(),
429 sender.clone(),
430 max_connections_per_peer,
431 staked_nodes.clone(),
432 max_staked_connections,
433 max_unstaked_connections,
434 max_streams_per_ms,
435 stats.clone(),
436 wait_for_chunk_timeout,
437 stream_load_ema.clone(),
438 ));
439 }
440 Err(err) => {
441 debug!("Incoming::accept(): error {:?}", err);
442 }
443 }
444 } else {
445 debug!("accept(): Timed out waiting for connection");
446 }
447 }
448}
449
450fn prune_unstaked_connection_table(
451 unstaked_connection_table: &mut ConnectionTable,
452 max_unstaked_connections: usize,
453 stats: Arc<StreamerStats>,
454) {
455 if unstaked_connection_table.total_size >= max_unstaked_connections {
456 const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
457 let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
458
459 let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
460 let num_pruned = unstaked_connection_table.prune_oldest(max_connections);
461 stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
462 }
463}
464
465pub fn get_remote_pubkey(connection: &Connection) -> Option<Pubkey> {
466 connection
468 .peer_identity()?
469 .downcast::<Vec<rustls::pki_types::CertificateDer>>()
470 .ok()
471 .filter(|certs| certs.len() == 1)?
472 .first()
473 .and_then(get_pubkey_from_tls_certificate)
474}
475
476fn get_connection_stake(
477 connection: &Connection,
478 staked_nodes: &RwLock<StakedNodes>,
479) -> Option<(Pubkey, u64, u64, u64, u64)> {
480 let pubkey = get_remote_pubkey(connection)?;
481 debug!("Peer public key is {pubkey:?}");
482 let staked_nodes = staked_nodes.read().unwrap();
483 Some((
484 pubkey,
485 staked_nodes.get_node_stake(&pubkey)?,
486 staked_nodes.total_stake(),
487 staked_nodes.max_stake(),
488 staked_nodes.min_stake(),
489 ))
490}
491
492pub fn compute_max_allowed_uni_streams(peer_type: ConnectionPeerType, total_stake: u64) -> usize {
493 match peer_type {
494 ConnectionPeerType::Staked(peer_stake) => {
495 if total_stake == 0 || peer_stake > total_stake {
497 warn!(
498 "Invalid stake values: peer_stake: {:?}, total_stake: {:?}",
499 peer_stake, total_stake,
500 );
501
502 QUIC_MIN_STAKED_CONCURRENT_STREAMS
503 } else {
504 let delta = (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS
505 - QUIC_MIN_STAKED_CONCURRENT_STREAMS) as f64;
506
507 (((peer_stake as f64 / total_stake as f64) * delta) as usize
508 + QUIC_MIN_STAKED_CONCURRENT_STREAMS)
509 .clamp(
510 QUIC_MIN_STAKED_CONCURRENT_STREAMS,
511 QUIC_MAX_STAKED_CONCURRENT_STREAMS,
512 )
513 }
514 }
515 ConnectionPeerType::Unstaked => QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
516 }
517}
518
/// Reasons a freshly established connection could not be registered.
enum ConnectionHandlerError {
    /// The connection table rejected the connection (`try_add_connection`
    /// returned `None`), or the target table allows no connections at all.
    ConnectionAddError,
    /// The computed maximum stream count did not fit in a QUIC `VarInt`.
    MaxStreamError,
}
523
524#[derive(Clone)]
525struct NewConnectionHandlerParams {
526 packet_sender: AsyncSender<PacketAccumulator>,
532 remote_pubkey: Option<Pubkey>,
533 peer_type: ConnectionPeerType,
534 total_stake: u64,
535 max_connections_per_peer: usize,
536 stats: Arc<StreamerStats>,
537 max_stake: u64,
538 min_stake: u64,
539}
540
541impl NewConnectionHandlerParams {
542 fn new_unstaked(
543 packet_sender: AsyncSender<PacketAccumulator>,
544 max_connections_per_peer: usize,
545 stats: Arc<StreamerStats>,
546 ) -> NewConnectionHandlerParams {
547 NewConnectionHandlerParams {
548 packet_sender,
549 remote_pubkey: None,
550 peer_type: ConnectionPeerType::Unstaked,
551 total_stake: 0,
552 max_connections_per_peer,
553 stats,
554 max_stake: 0,
555 min_stake: 0,
556 }
557 }
558}
559
560fn handle_and_cache_new_connection(
561 client_connection_tracker: ClientConnectionTracker,
562 connection: Connection,
563 mut connection_table_l: MutexGuard<ConnectionTable>,
564 connection_table: Arc<Mutex<ConnectionTable>>,
565 params: &NewConnectionHandlerParams,
566 wait_for_chunk_timeout: Duration,
567 stream_load_ema: Arc<StakedStreamLoadEMA>,
568) -> Result<(), ConnectionHandlerError> {
569 if let Ok(max_uni_streams) = VarInt::from_u64(compute_max_allowed_uni_streams(
570 params.peer_type,
571 params.total_stake,
572 ) as u64)
573 {
574 let remote_addr = connection.remote_address();
575 let receive_window =
576 compute_recieve_window(params.max_stake, params.min_stake, params.peer_type);
577
578 debug!(
579 "Peer type {:?}, total stake {}, max streams {} receive_window {:?} from peer {}",
580 params.peer_type,
581 params.total_stake,
582 max_uni_streams.into_inner(),
583 receive_window,
584 remote_addr,
585 );
586
587 if let Some((last_update, cancel_connection, stream_counter)) = connection_table_l
588 .try_add_connection(
589 ConnectionTableKey::new(remote_addr.ip(), params.remote_pubkey),
590 remote_addr.port(),
591 client_connection_tracker,
592 Some(connection.clone()),
593 params.peer_type,
594 timing::timestamp(),
595 params.max_connections_per_peer,
596 )
597 {
598 drop(connection_table_l);
599
600 if let Ok(receive_window) = receive_window {
601 connection.set_receive_window(receive_window);
602 }
603 connection.set_max_concurrent_uni_streams(max_uni_streams);
604
605 tokio::spawn(handle_connection(
606 connection,
607 remote_addr,
608 last_update,
609 connection_table,
610 cancel_connection,
611 params.clone(),
612 wait_for_chunk_timeout,
613 stream_load_ema,
614 stream_counter,
615 ));
616 Ok(())
617 } else {
618 params
619 .stats
620 .connection_add_failed
621 .fetch_add(1, Ordering::Relaxed);
622 Err(ConnectionHandlerError::ConnectionAddError)
623 }
624 } else {
625 connection.close(
626 CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT.into(),
627 CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT,
628 );
629 params
630 .stats
631 .connection_add_failed_invalid_stream_count
632 .fetch_add(1, Ordering::Relaxed);
633 Err(ConnectionHandlerError::MaxStreamError)
634 }
635}
636
637async fn prune_unstaked_connections_and_add_new_connection(
638 client_connection_tracker: ClientConnectionTracker,
639 connection: Connection,
640 connection_table: Arc<Mutex<ConnectionTable>>,
641 max_connections: usize,
642 params: &NewConnectionHandlerParams,
643 wait_for_chunk_timeout: Duration,
644 stream_load_ema: Arc<StakedStreamLoadEMA>,
645) -> Result<(), ConnectionHandlerError> {
646 let stats = params.stats.clone();
647 if max_connections > 0 {
648 let connection_table_clone = connection_table.clone();
649 let mut connection_table = connection_table.lock().await;
650 prune_unstaked_connection_table(&mut connection_table, max_connections, stats);
651 handle_and_cache_new_connection(
652 client_connection_tracker,
653 connection,
654 connection_table,
655 connection_table_clone,
656 params,
657 wait_for_chunk_timeout,
658 stream_load_ema,
659 )
660 } else {
661 connection.close(
662 CONNECTION_CLOSE_CODE_DISALLOWED.into(),
663 CONNECTION_CLOSE_REASON_DISALLOWED,
664 );
665 Err(ConnectionHandlerError::ConnectionAddError)
666 }
667}
668
669fn compute_receive_window_ratio_for_staked_node(max_stake: u64, min_stake: u64, stake: u64) -> u64 {
671 if stake > max_stake {
680 return QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
681 }
682
683 let max_ratio = QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
684 let min_ratio = QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO;
685 if max_stake > min_stake {
686 let a = (max_ratio - min_ratio) as f64 / (max_stake - min_stake) as f64;
687 let b = max_ratio as f64 - ((max_stake as f64) * a);
688 let ratio = (a * stake as f64) + b;
689 ratio.round() as u64
690 } else {
691 QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO
692 }
693}
694
695fn compute_recieve_window(
696 max_stake: u64,
697 min_stake: u64,
698 peer_type: ConnectionPeerType,
699) -> Result<VarInt, VarIntBoundsExceeded> {
700 match peer_type {
701 ConnectionPeerType::Unstaked => {
702 VarInt::from_u64(PACKET_DATA_SIZE as u64 * QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO)
703 }
704 ConnectionPeerType::Staked(peer_stake) => {
705 let ratio =
706 compute_receive_window_ratio_for_staked_node(max_stake, min_stake, peer_stake);
707 VarInt::from_u64(PACKET_DATA_SIZE as u64 * ratio)
708 }
709 }
710}
711
/// Completes the QUIC handshake for one incoming connection and registers the
/// connection in the staked or unstaked connection table.
///
/// - The handshake is bounded by `QUIC_CONNECTION_HANDSHAKE_TIMEOUT`.
///   - On timeout, only `connection_setup_timeout` is bumped.
///   - Handshake errors are classified by `handle_connection_error`.
/// - `outstanding_incoming_connection_attempts` is decremented as soon as the
///   handshake finishes or times out. This pairs with the increment made in
///   `run_server` before `Incoming::accept()`.
/// - Peer classification:
///   - Peers for which `get_connection_stake` returns `None` use
///     `NewConnectionHandlerParams::new_unstaked`.
///   - Staked peers whose share of `total_stake` is below
///     `1 / (max_streams_per_ms * STREAM_THROTTLING_INTERVAL_MS)` are also
///     classified as unstaked.
/// - Table placement:
///   - Staked peers go into `staked_connection_table`, after a random-sample
///     prune if that table is full.
///   - If the staked table is still full, they fall back to
///     `unstaked_connection_table`.
#[allow(clippy::too_many_arguments)]
async fn setup_connection(
    connecting: Connecting,
    client_connection_tracker: ClientConnectionTracker,
    unstaked_connection_table: Arc<Mutex<ConnectionTable>>,
    staked_connection_table: Arc<Mutex<ConnectionTable>>,
    packet_sender: AsyncSender<PacketAccumulator>,
    max_connections_per_peer: usize,
    staked_nodes: Arc<RwLock<StakedNodes>>,
    max_staked_connections: usize,
    max_unstaked_connections: usize,
    max_streams_per_ms: u64,
    stats: Arc<StreamerStats>,
    wait_for_chunk_timeout: Duration,
    stream_load_ema: Arc<StakedStreamLoadEMA>,
) {
    // Sample size passed to `ConnectionTable::prune_random` when the staked table is full.
    const PRUNE_RANDOM_SAMPLE_SIZE: usize = 2;
    let from = connecting.remote_address();
    let res = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await;
    // The attempt is no longer outstanding, whatever the handshake outcome.
    stats
        .outstanding_incoming_connection_attempts
        .fetch_sub(1, Ordering::Relaxed);
    if let Ok(connecting_result) = res {
        match connecting_result {
            Ok(new_connection) => {
                stats.total_new_connections.fetch_add(1, Ordering::Relaxed);

                let params = get_connection_stake(&new_connection, &staked_nodes).map_or(
                    NewConnectionHandlerParams::new_unstaked(
                        packet_sender.clone(),
                        max_connections_per_peer,
                        stats.clone(),
                    ),
                    |(pubkey, stake, total_stake, max_stake, min_stake)| {
                        // Smallest stake share still treated as staked. Presumably this
                        // corresponds to one stream out of the
                        // `max_streams_per_ms * STREAM_THROTTLING_INTERVAL_MS` streams
                        // allowed per throttling interval — confirm with
                        // `StakedStreamLoadEMA`.
                        let min_stake_ratio =
                            1_f64 / (max_streams_per_ms * STREAM_THROTTLING_INTERVAL_MS) as f64;
                        let stake_ratio = stake as f64 / total_stake as f64;
                        let peer_type = if stake_ratio < min_stake_ratio {
                            // Negligible stake: handle the peer as unstaked.
                            ConnectionPeerType::Unstaked
                        } else {
                            ConnectionPeerType::Staked(stake)
                        };
                        NewConnectionHandlerParams {
                            packet_sender,
                            remote_pubkey: Some(pubkey),
                            peer_type,
                            total_stake,
                            max_connections_per_peer,
                            stats: stats.clone(),
                            max_stake,
                            min_stake,
                        }
                    },
                );

                match params.peer_type {
                    ConnectionPeerType::Staked(stake) => {
                        let mut connection_table_l = staked_connection_table.lock().await;

                        if connection_table_l.total_size >= max_staked_connections {
                            // The new peer's stake is passed to the pruner. Presumably
                            // only lower-staked entries are eligible for eviction —
                            // confirm in `ConnectionTable::prune_random`.
                            let num_pruned =
                                connection_table_l.prune_random(PRUNE_RANDOM_SAMPLE_SIZE, stake);
                            stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
                        }

                        if connection_table_l.total_size < max_staked_connections {
                            if let Ok(()) = handle_and_cache_new_connection(
                                client_connection_tracker,
                                new_connection,
                                connection_table_l,
                                staked_connection_table.clone(),
                                &params,
                                wait_for_chunk_timeout,
                                stream_load_ema.clone(),
                            ) {
                                stats
                                    .connection_added_from_staked_peer
                                    .fetch_add(1, Ordering::Relaxed);
                            }
                        } else {
                            // The staked table is still full: fall back to the unstaked
                            // table, pruning it if needed.
                            // NOTE(review): `connection_table_l` (the staked-table guard)
                            // is still held while awaiting the unstaked-table lock here.
                            if let Ok(()) = prune_unstaked_connections_and_add_new_connection(
                                client_connection_tracker,
                                new_connection,
                                unstaked_connection_table.clone(),
                                max_unstaked_connections,
                                &params,
                                wait_for_chunk_timeout,
                                stream_load_ema.clone(),
                            )
                            .await
                            {
                                stats
                                    .connection_added_from_staked_peer
                                    .fetch_add(1, Ordering::Relaxed);
                            } else {
                                stats
                                    .connection_add_failed_on_pruning
                                    .fetch_add(1, Ordering::Relaxed);
                                stats
                                    .connection_add_failed_staked_node
                                    .fetch_add(1, Ordering::Relaxed);
                            }
                        }
                    }
                    ConnectionPeerType::Unstaked => {
                        if let Ok(()) = prune_unstaked_connections_and_add_new_connection(
                            client_connection_tracker,
                            new_connection,
                            unstaked_connection_table.clone(),
                            max_unstaked_connections,
                            &params,
                            wait_for_chunk_timeout,
                            stream_load_ema.clone(),
                        )
                        .await
                        {
                            stats
                                .connection_added_from_unstaked_peer
                                .fetch_add(1, Ordering::Relaxed);
                        } else {
                            stats
                                .connection_add_failed_unstaked_node
                                .fetch_add(1, Ordering::Relaxed);
                        }
                    }
                }
            }
            Err(e) => {
                handle_connection_error(e, &stats, from);
            }
        }
    } else {
        // Handshake did not complete within QUIC_CONNECTION_HANDSHAKE_TIMEOUT.
        stats
            .connection_setup_timeout
            .fetch_add(1, Ordering::Relaxed);
    }
}
855
856fn handle_connection_error(e: quinn::ConnectionError, stats: &StreamerStats, from: SocketAddr) {
857 debug!("error: {:?} from: {:?}", e, from);
858 stats.connection_setup_error.fetch_add(1, Ordering::Relaxed);
859 match e {
860 quinn::ConnectionError::TimedOut => {
861 stats
862 .connection_setup_error_timed_out
863 .fetch_add(1, Ordering::Relaxed);
864 }
865 quinn::ConnectionError::ConnectionClosed(_) => {
866 stats
867 .connection_setup_error_closed
868 .fetch_add(1, Ordering::Relaxed);
869 }
870 quinn::ConnectionError::TransportError(_) => {
871 stats
872 .connection_setup_error_transport
873 .fetch_add(1, Ordering::Relaxed);
874 }
875 quinn::ConnectionError::ApplicationClosed(_) => {
876 stats
877 .connection_setup_error_app_closed
878 .fetch_add(1, Ordering::Relaxed);
879 }
880 quinn::ConnectionError::Reset => {
881 stats
882 .connection_setup_error_reset
883 .fetch_add(1, Ordering::Relaxed);
884 }
885 quinn::ConnectionError::LocallyClosed => {
886 stats
887 .connection_setup_error_locally_closed
888 .fetch_add(1, Ordering::Relaxed);
889 }
890 _ => {}
891 }
892}
893
/// Batching task: turns reassembled packets into `PacketBatch`es.
///
/// It drains `PacketAccumulator`s from the async channel fed by connection
/// handlers, copies each one into a `PacketBatch`, and forwards every batch
/// over the crossbeam `packet_sender`.
///
/// A batch is flushed when either:
/// - it holds `PACKETS_PER_BATCH` packets, or
/// - it is non-empty and `coalesce` has elapsed since its first packet arrived.
///
/// If a send fails, `total_packet_batch_send_err` is bumped and that batch is
/// dropped. The task returns only after observing `exit`.
async fn packet_batch_sender(
    packet_sender: Sender<PacketBatch>,
    packet_receiver: AsyncReceiver<PacketAccumulator>,
    exit: Arc<AtomicBool>,
    stats: Arc<StreamerStats>,
    coalesce: Duration,
) {
    trace!("enter packet_batch_sender");
    // Time at which the current batch received its first packet.
    let mut batch_start_time = Instant::now();
    loop {
        // (signature, accumulator start time) of packets sampled for perf tracking.
        let mut packet_perf_measure: Vec<([u8; 64], Instant)> = Vec::default();
        let mut packet_batch = PacketBatch::with_capacity(PACKETS_PER_BATCH);
        let mut total_bytes: usize = 0;

        stats
            .total_packet_batches_allocated
            .fetch_add(1, Ordering::Relaxed);
        stats
            .total_packets_allocated
            .fetch_add(PACKETS_PER_BATCH, Ordering::Relaxed);

        loop {
            if exit.load(Ordering::Relaxed) {
                return;
            }
            let elapsed = batch_start_time.elapsed();
            if packet_batch.len() >= PACKETS_PER_BATCH
                || (!packet_batch.is_empty() && elapsed >= coalesce)
            {
                let len = packet_batch.len();
                track_streamer_fetch_packet_performance(&packet_perf_measure, &stats);

                if let Err(e) = packet_sender.send(packet_batch) {
                    stats
                        .total_packet_batch_send_err
                        .fetch_add(1, Ordering::Relaxed);
                    trace!("Send error: {}", e);
                } else {
                    stats
                        .total_packet_batches_sent
                        .fetch_add(1, Ordering::Relaxed);

                    stats
                        .total_packets_sent_to_consumer
                        .fetch_add(len, Ordering::Relaxed);

                    stats
                        .total_bytes_sent_to_consumer
                        .fetch_add(total_bytes, Ordering::Relaxed);

                    trace!("Sent {} packet batch", len);
                }
                // Start a fresh batch in the outer loop.
                break;
            }

            let timeout_res = if !packet_batch.is_empty() {
                // A non-empty batch here implies `elapsed < coalesce`; otherwise it
                // would have been flushed above. So this subtraction cannot underflow.
                timeout(coalesce - elapsed, packet_receiver.recv()).await
            } else {
                // Nothing to flush, so wait without a deadline. `exit` is not
                // re-checked until this await returns.
                Ok(packet_receiver.recv().await)
            };

            // NOTE(review): a receive error (`Ok(Err(_))`, e.g. a closed channel) is
            // silently ignored. If `recv()` keeps failing immediately, this loop spins
            // until `exit` is set — confirm that is acceptable.
            if let Ok(Ok(packet_accumulator)) = timeout_res {
                if packet_batch.is_empty() {
                    // The coalesce window starts with the batch's first packet.
                    batch_start_time = Instant::now();
                }

                // NOTE(review): grows the batch into the capacity reserved by
                // `PacketBatch::with_capacity(PACKETS_PER_BATCH)`. `len < PACKETS_PER_BATCH`
                // holds here because a full batch is flushed above. The new packet's meta
                // is overwritten below, but its buffer is only written for the bytes
                // covered by the chunks — confirm the soundness argument for `set_len`
                // on this type.
                unsafe {
                    packet_batch.set_len(packet_batch.len() + 1);
                }

                let i = packet_batch.len() - 1;
                *packet_batch[i].meta_mut() = packet_accumulator.meta;
                let num_chunks = packet_accumulator.chunks.len();
                // Concatenate the stream's chunks into the packet buffer. This
                // assumes the total stays within the buffer size (`handle_chunks`
                // rejects streams larger than PACKET_DATA_SIZE).
                let mut offset = 0;
                for chunk in packet_accumulator.chunks {
                    packet_batch[i].buffer_mut()[offset..offset + chunk.len()]
                        .copy_from_slice(&chunk);
                    offset += chunk.len();
                }

                total_bytes += packet_batch[i].meta().size;

                // Sampled packets are marked for performance tracking and timed on flush.
                if let Some(signature) = signature_if_should_track_packet(&packet_batch[i])
                    .ok()
                    .flatten()
                {
                    packet_perf_measure.push((*signature, packet_accumulator.start_time));
                    packet_batch[i].meta_mut().set_track_performance(true);
                }
                stats
                    .total_chunks_processed_by_batcher
                    .fetch_add(num_chunks, Ordering::Relaxed);
            }
        }
    }
}
1002
1003fn track_streamer_fetch_packet_performance(
1004 packet_perf_measure: &[([u8; 64], Instant)],
1005 stats: &StreamerStats,
1006) {
1007 if packet_perf_measure.is_empty() {
1008 return;
1009 }
1010 let mut measure = Measure::start("track_perf");
1011 let mut process_sampled_packets_us_hist = stats.process_sampled_packets_us_hist.lock().unwrap();
1012
1013 let now = Instant::now();
1014 for (signature, start_time) in packet_perf_measure {
1015 let duration = now.duration_since(*start_time);
1016 debug!(
1017 "QUIC streamer fetch stage took {duration:?} for transaction {:?}",
1018 Signature::from(*signature)
1019 );
1020 process_sampled_packets_us_hist
1021 .increment(duration.as_micros() as u64)
1022 .unwrap();
1023 }
1024
1025 drop(process_sampled_packets_us_hist);
1026 measure.stop();
1027 stats
1028 .perf_track_overhead_us
1029 .fetch_add(measure.as_us(), Ordering::Relaxed);
1030}
1031
/// Per-connection task. It:
/// - accepts uni-directional streams from the peer,
/// - throttles them using `stream_load_ema` and `stream_counter`,
/// - reads each stream's chunks into a `PacketAccumulator` via `handle_chunks`,
/// - on exit, removes the connection from `connection_table`.
///
/// The stream loop ends when any of these happens:
/// - `accept_uni()` fails;
/// - `cancel` fires;
/// - `handle_chunks` rejects a stream. In that case the connection is closed
///   with `CONNECTION_CLOSE_CODE_INVALID_STREAM`.
async fn handle_connection(
    connection: Connection,
    remote_addr: SocketAddr,
    last_update: Arc<AtomicU64>,
    connection_table: Arc<Mutex<ConnectionTable>>,
    cancel: CancellationToken,
    params: NewConnectionHandlerParams,
    wait_for_chunk_timeout: Duration,
    stream_load_ema: Arc<StakedStreamLoadEMA>,
    stream_counter: Arc<ConnectionStreamCounter>,
) {
    let NewConnectionHandlerParams {
        packet_sender,
        peer_type,
        remote_pubkey,
        stats,
        total_stake,
        ..
    } = params;

    debug!(
        "quic new connection {} streams: {} connections: {}",
        remote_addr,
        stats.total_streams.load(Ordering::Relaxed),
        stats.total_connections.load(Ordering::Relaxed),
    );
    stats.total_connections.fetch_add(1, Ordering::Relaxed);

    'conn: loop {
        // Wait for the next stream, or stop when the connection is cancelled.
        let mut stream = select! {
            stream = connection.accept_uni() => match stream {
                Ok(stream) => stream,
                Err(e) => {
                    debug!("stream error: {:?}", e);
                    break;
                }
            },
            _ = cancel.cancelled() => break,
        };

        let max_streams_per_throttling_interval =
            stream_load_ema.available_load_capacity_in_throttling_duration(peer_type, total_stake);

        // If this connection has already read its allowance of streams for the
        // current throttling interval, sleep for the rest of that interval before
        // reading this stream.
        let throttle_interval_start = stream_counter.reset_throttling_params_if_needed();
        let streams_read_in_throttle_interval = stream_counter.stream_count.load(Ordering::Relaxed);
        if streams_read_in_throttle_interval >= max_streams_per_throttling_interval {
            let throttle_duration =
                STREAM_THROTTLING_INTERVAL.saturating_sub(throttle_interval_start.elapsed());

            if !throttle_duration.is_zero() {
                debug!("Throttling stream from {remote_addr:?}, peer type: {:?}, total stake: {}, \
                                    max_streams_per_interval: {max_streams_per_throttling_interval}, read_interval_streams: {streams_read_in_throttle_interval} \
                                    throttle_duration: {throttle_duration:?}",
                                    peer_type, total_stake);
                stats.throttled_streams.fetch_add(1, Ordering::Relaxed);
                match peer_type {
                    ConnectionPeerType::Unstaked => {
                        stats
                            .throttled_unstaked_streams
                            .fetch_add(1, Ordering::Relaxed);
                    }
                    ConnectionPeerType::Staked(_) => {
                        stats
                            .throttled_staked_streams
                            .fetch_add(1, Ordering::Relaxed);
                    }
                }
                sleep(throttle_duration).await;
            }
        }
        stream_load_ema.increment_load(peer_type);
        stream_counter.stream_count.fetch_add(1, Ordering::Relaxed);
        stats.total_streams.fetch_add(1, Ordering::Relaxed);
        stats.total_new_streams.fetch_add(1, Ordering::Relaxed);

        let mut meta = Meta::default();
        meta.set_socket_addr(&remote_addr);
        meta.set_from_staked_node(matches!(peer_type, ConnectionPeerType::Staked(_)));
        let mut accum = PacketAccumulator::new(meta);

        // Scratch buffer: each `read_chunks` call fills at most these 4 slots.
        let mut chunks: [Bytes; 4] = array::from_fn(|_| Bytes::new());

        loop {
            // Wait up to `wait_for_chunk_timeout` for the next chunks. A timeout or
            // read error abandons this stream. Cancellation breaks this inner loop
            // only; the outer `select!` then observes the cancellation.
            let n_chunks = match tokio::select! {
                chunk = tokio::time::timeout(
                    wait_for_chunk_timeout,
                    stream.read_chunks(&mut chunks)) => chunk,

                _ = cancel.cancelled() => break,
            } {
                // `None` (presumably end of stream) is passed on as zero chunks.
                Ok(Ok(chunk)) => chunk.unwrap_or(0),
                Ok(Err(e)) => {
                    debug!("Received stream error: {:?}", e);
                    stats
                        .total_stream_read_errors
                        .fetch_add(1, Ordering::Relaxed);
                    break;
                }
                Err(_) => {
                    debug!("Timeout in receiving on stream");
                    stats
                        .total_stream_read_timeouts
                        .fetch_add(1, Ordering::Relaxed);
                    break;
                }
            };

            match handle_chunks(
                // `Bytes::clone` only bumps a reference count.
                chunks.iter().take(n_chunks).cloned(),
                &mut accum,
                &packet_sender,
                &stats,
                peer_type,
            )
            .await
            {
                Ok(StreamState::Finished) => {
                    last_update.store(timing::timestamp(), Ordering::Relaxed);
                    break;
                }
                Ok(StreamState::Receiving) => {}
                Err(_) => {
                    // Invalid stream: disconnect the peer. The bookkeeping after the
                    // inner loop is skipped by `break 'conn`, so do it here.
                    connection.close(
                        CONNECTION_CLOSE_CODE_INVALID_STREAM.into(),
                        CONNECTION_CLOSE_REASON_INVALID_STREAM,
                    );
                    stats.total_streams.fetch_sub(1, Ordering::Relaxed);
                    stream_load_ema.update_ema_if_needed();
                    break 'conn;
                }
            }
        }

        stats.total_streams.fetch_sub(1, Ordering::Relaxed);
        stream_load_ema.update_ema_if_needed();
    }

    // Drop this connection's entry from the table it was registered in.
    let stable_id = connection.stable_id();
    let removed_connection_count = connection_table.lock().await.remove_connection(
        ConnectionTableKey::new(remote_addr.ip(), remote_pubkey),
        remote_addr.port(),
        stable_id,
    );
    if removed_connection_count > 0 {
        stats
            .connection_removed
            .fetch_add(removed_connection_count, Ordering::Relaxed);
    } else {
        stats
            .connection_remove_failed
            .fetch_add(1, Ordering::Relaxed);
    }
    stats.total_connections.fetch_sub(1, Ordering::Relaxed);
}
1209
/// Result of feeding one read of chunks into a `PacketAccumulator` via `handle_chunks`.
enum StreamState {
    /// The read returned data; more chunks may still arrive on this stream.
    Receiving,
    /// End of stream was reached and the accumulated packet was passed to the packet
    /// sender (the send itself may have failed; that is only recorded in stats).
    Finished,
}
1216
1217async fn handle_chunks(
1222 chunks: impl ExactSizeIterator<Item = Bytes>,
1223 accum: &mut PacketAccumulator,
1224 packet_sender: &AsyncSender<PacketAccumulator>,
1225 stats: &StreamerStats,
1226 peer_type: ConnectionPeerType,
1227) -> Result<StreamState, ()> {
1228 let n_chunks = chunks.len();
1229 for chunk in chunks {
1230 accum.meta.size += chunk.len();
1231 if accum.meta.size > PACKET_DATA_SIZE {
1232 stats.invalid_stream_size.fetch_add(1, Ordering::Relaxed);
1236 debug!("invalid stream size {}", accum.meta.size);
1237 return Err(());
1238 }
1239 accum.chunks.push(chunk);
1240 if peer_type.is_staked() {
1241 stats
1242 .total_staked_chunks_received
1243 .fetch_add(1, Ordering::Relaxed);
1244 } else {
1245 stats
1246 .total_unstaked_chunks_received
1247 .fetch_add(1, Ordering::Relaxed);
1248 }
1249 }
1250
1251 if n_chunks != 0 {
1253 return Ok(StreamState::Receiving);
1254 }
1255
1256 if accum.chunks.is_empty() {
1257 debug!("stream is empty");
1258 stats
1259 .total_packet_batches_none
1260 .fetch_add(1, Ordering::Relaxed);
1261 return Err(());
1262 }
1263
1264 let bytes_sent = accum.meta.size;
1266 let chunks_sent = accum.chunks.len();
1267
1268 if let Err(err) = packet_sender.send(accum.clone()).await {
1269 stats
1270 .total_handle_chunk_to_packet_batcher_send_err
1271 .fetch_add(1, Ordering::Relaxed);
1272 trace!("packet batch send error {:?}", err);
1273 } else {
1274 stats
1275 .total_packets_sent_for_batching
1276 .fetch_add(1, Ordering::Relaxed);
1277 stats
1278 .total_bytes_sent_for_batching
1279 .fetch_add(bytes_sent, Ordering::Relaxed);
1280 stats
1281 .total_chunks_sent_for_batching
1282 .fetch_add(chunks_sent, Ordering::Relaxed);
1283
1284 match peer_type {
1285 ConnectionPeerType::Unstaked => {
1286 stats
1287 .total_unstaked_packets_sent_for_batching
1288 .fetch_add(1, Ordering::Relaxed);
1289 }
1290 ConnectionPeerType::Staked(_) => {
1291 stats
1292 .total_staked_packets_sent_for_batching
1293 .fetch_add(1, Ordering::Relaxed);
1294 }
1295 }
1296
1297 trace!("sent {} byte packet for batching", bytes_sent);
1298 }
1299
1300 Ok(StreamState::Finished)
1301}
1302
/// Bookkeeping for a single connection stored in a `ConnectionTable`.
#[derive(Debug)]
struct ConnectionEntry {
    // Cancelled when this entry is dropped (see the `Drop` impl).
    cancel: CancellationToken,
    // Staked/unstaked classification; `stake()` reads the stake out of it.
    peer_type: ConnectionPeerType,
    // Shared timestamp; `ConnectionTable::prune_oldest` uses it to find the stalest peers.
    last_update: Arc<AtomicU64>,
    // Remote port; matched (together with the connection's stable id) on removal.
    port: u16,
    // Held only for its lifetime and never read. NOTE(review): presumably maintains the
    // `open_connections` stat on drop (the table tests expect it back at 0) — confirm.
    _client_connection_tracker: ClientConnectionTracker,
    // The QUIC connection; closed when the entry is dropped. `None` in some unit tests.
    connection: Option<Connection>,
    // Stream counter shared by all connections under the same table key.
    stream_counter: Arc<ConnectionStreamCounter>,
}
1314
1315impl ConnectionEntry {
1316 fn new(
1317 cancel: CancellationToken,
1318 peer_type: ConnectionPeerType,
1319 last_update: Arc<AtomicU64>,
1320 port: u16,
1321 client_connection_tracker: ClientConnectionTracker,
1322 connection: Option<Connection>,
1323 stream_counter: Arc<ConnectionStreamCounter>,
1324 ) -> Self {
1325 Self {
1326 cancel,
1327 peer_type,
1328 last_update,
1329 port,
1330 _client_connection_tracker: client_connection_tracker,
1331 connection,
1332 stream_counter,
1333 }
1334 }
1335
1336 fn last_update(&self) -> u64 {
1337 self.last_update.load(Ordering::Relaxed)
1338 }
1339
1340 fn stake(&self) -> u64 {
1341 match self.peer_type {
1342 ConnectionPeerType::Unstaked => 0,
1343 ConnectionPeerType::Staked(stake) => stake,
1344 }
1345 }
1346}
1347
1348impl Drop for ConnectionEntry {
1349 fn drop(&mut self) {
1350 if let Some(conn) = self.connection.take() {
1351 conn.close(
1352 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1353 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1354 );
1355 }
1356 self.cancel.cancel();
1357 }
1358}
1359
/// Key under which connections are grouped in a `ConnectionTable`: the peer's pubkey
/// when one is known, otherwise its IP address (see `ConnectionTableKey::new`).
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
enum ConnectionTableKey {
    IP(IpAddr),
    Pubkey(Pubkey),
}
1365
1366impl ConnectionTableKey {
1367 fn new(ip: IpAddr, maybe_pubkey: Option<Pubkey>) -> Self {
1368 maybe_pubkey.map_or(ConnectionTableKey::IP(ip), |pubkey| {
1369 ConnectionTableKey::Pubkey(pubkey)
1370 })
1371 }
1372}
1373
/// Connections grouped per peer key, plus a running total of all stored connections.
struct ConnectionTable {
    // Per-key list of entries. `IndexMap` lets the pruning code address and
    // `swap_remove_index` entries by position.
    table: IndexMap<ConnectionTableKey, Vec<ConnectionEntry>>,
    // Total number of `ConnectionEntry` values across all keys.
    total_size: usize,
}
1379
1380impl ConnectionTable {
1383 fn new() -> Self {
1384 Self {
1385 table: IndexMap::default(),
1386 total_size: 0,
1387 }
1388 }
1389
1390 fn prune_oldest(&mut self, max_size: usize) -> usize {
1391 let mut num_pruned = 0;
1392 let key = |(_, connections): &(_, &Vec<_>)| {
1393 connections.iter().map(ConnectionEntry::last_update).min()
1394 };
1395 while self.total_size.saturating_sub(num_pruned) > max_size {
1396 match self.table.values().enumerate().min_by_key(key) {
1397 None => break,
1398 Some((index, connections)) => {
1399 num_pruned += connections.len();
1400 self.table.swap_remove_index(index);
1401 }
1402 }
1403 }
1404 self.total_size = self.total_size.saturating_sub(num_pruned);
1405 num_pruned
1406 }
1407
1408 fn prune_random(&mut self, sample_size: usize, threshold_stake: u64) -> usize {
1413 let num_pruned = std::iter::once(self.table.len())
1414 .filter(|&size| size > 0)
1415 .flat_map(|size| {
1416 let mut rng = thread_rng();
1417 repeat_with(move || rng.gen_range(0..size))
1418 })
1419 .map(|index| {
1420 let connection = self.table[index].first();
1421 let stake = connection.map(|connection: &ConnectionEntry| connection.stake());
1422 (index, stake)
1423 })
1424 .take(sample_size)
1425 .min_by_key(|&(_, stake)| stake)
1426 .filter(|&(_, stake)| stake < Some(threshold_stake))
1427 .and_then(|(index, _)| self.table.swap_remove_index(index))
1428 .map(|(_, connections)| connections.len())
1429 .unwrap_or_default();
1430 self.total_size = self.total_size.saturating_sub(num_pruned);
1431 num_pruned
1432 }
1433
1434 fn try_add_connection(
1435 &mut self,
1436 key: ConnectionTableKey,
1437 port: u16,
1438 client_connection_tracker: ClientConnectionTracker,
1439 connection: Option<Connection>,
1440 peer_type: ConnectionPeerType,
1441 last_update: u64,
1442 max_connections_per_peer: usize,
1443 ) -> Option<(
1444 Arc<AtomicU64>,
1445 CancellationToken,
1446 Arc<ConnectionStreamCounter>,
1447 )> {
1448 let connection_entry = self.table.entry(key).or_default();
1449 let has_connection_capacity = connection_entry
1450 .len()
1451 .checked_add(1)
1452 .map(|c| c <= max_connections_per_peer)
1453 .unwrap_or(false);
1454 if has_connection_capacity {
1455 let cancel = CancellationToken::new();
1456 let last_update = Arc::new(AtomicU64::new(last_update));
1457 let stream_counter = connection_entry
1458 .first()
1459 .map(|entry| entry.stream_counter.clone())
1460 .unwrap_or(Arc::new(ConnectionStreamCounter::new()));
1461 connection_entry.push(ConnectionEntry::new(
1462 cancel.clone(),
1463 peer_type,
1464 last_update.clone(),
1465 port,
1466 client_connection_tracker,
1467 connection,
1468 stream_counter.clone(),
1469 ));
1470 self.total_size += 1;
1471 Some((last_update, cancel, stream_counter))
1472 } else {
1473 if let Some(connection) = connection {
1474 connection.close(
1475 CONNECTION_CLOSE_CODE_TOO_MANY.into(),
1476 CONNECTION_CLOSE_REASON_TOO_MANY,
1477 );
1478 }
1479 None
1480 }
1481 }
1482
1483 fn remove_connection(&mut self, key: ConnectionTableKey, port: u16, stable_id: usize) -> usize {
1485 if let Entry::Occupied(mut e) = self.table.entry(key) {
1486 let e_ref = e.get_mut();
1487 let old_size = e_ref.len();
1488
1489 e_ref.retain(|connection_entry| {
1490 connection_entry.port != port
1496 || connection_entry
1497 .connection
1498 .as_ref()
1499 .and_then(|connection| (connection.stable_id() != stable_id).then_some(0))
1500 .is_some()
1501 });
1502 let new_size = e_ref.len();
1503 if e_ref.is_empty() {
1504 e.swap_remove_entry();
1505 }
1506 let connections_removed = old_size.saturating_sub(new_size);
1507 self.total_size = self.total_size.saturating_sub(connections_removed);
1508 connections_removed
1509 } else {
1510 0
1511 }
1512 }
1513}
1514
/// Future wrapping a quinn `Accept` that also reports which endpoint produced the result.
struct EndpointAccept<'a> {
    // Identifier returned alongside the accept result. NOTE(review): presumably the index
    // of the endpoint in the server's endpoint list — confirm at the construction site.
    endpoint: usize,
    // The wrapped accept future; only ever polled through a pinned projection.
    accept: Accept<'a>,
}
1519
impl<'a> Future for EndpointAccept<'a> {
    type Output = (Option<quinn::Incoming>, usize);

    /// Polls the inner `Accept` and tags its result with `self.endpoint`.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
        let i = self.endpoint;
        // SAFETY: `self` is pinned and `accept` is a field of it. This impl only reaches
        // `accept` through this pinned projection and never moves it out, which is what
        // `Pin::map_unchecked_mut` requires. NOTE(review): this also relies on
        // `EndpointAccept` having no manual `Unpin`/`Drop` impl that moves `accept` —
        // none is visible here; confirm none exists elsewhere.
        unsafe { self.map_unchecked_mut(|this| &mut this.accept) }
            .poll(cx)
            .map(|r| (r, i))
    }
}
1532}
1533
1534#[cfg(test)]
1535pub mod test {
1536 use {
1537 super::*,
1538 crate::{
1539 nonblocking::{
1540 quic::compute_max_allowed_uni_streams,
1541 testing_utilities::{
1542 get_client_config, make_client_endpoint, setup_quic_server,
1543 SpawnTestServerResult, TestServerConfig,
1544 },
1545 },
1546 quic::{MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS},
1547 },
1548 assert_matches::assert_matches,
1549 async_channel::unbounded as async_unbounded,
1550 crossbeam_channel::{unbounded, Receiver},
1551 quinn::{ApplicationClose, ConnectionError},
1552 solana_sdk::{net::DEFAULT_TPU_COALESCE, signature::Keypair, signer::Signer},
1553 std::collections::HashMap,
1554 tokio::time::sleep,
1555 };
1556
1557 pub async fn check_timeout(receiver: Receiver<PacketBatch>, server_address: SocketAddr) {
1558 let conn1 = make_client_endpoint(&server_address, None).await;
1559 let total = 30;
1560 for i in 0..total {
1561 let mut s1 = conn1.open_uni().await.unwrap();
1562 s1.write_all(&[0u8]).await.unwrap();
1563 s1.finish().unwrap();
1564 info!("done {}", i);
1565 sleep(Duration::from_millis(1000)).await;
1566 }
1567 let mut received = 0;
1568 loop {
1569 if let Ok(_x) = receiver.try_recv() {
1570 received += 1;
1571 info!("got {}", received);
1572 } else {
1573 sleep(Duration::from_millis(500)).await;
1574 }
1575 if received >= total {
1576 break;
1577 }
1578 }
1579 }
1580
/// Opens two client connections to the server and checks that the second one cannot carry
/// data. Either opening its stream fails with `ApplicationClosed`, or writing to it errors.
pub async fn check_block_multiple_connections(server_address: SocketAddr) {
    let conn1 = make_client_endpoint(&server_address, None).await;
    let conn2 = make_client_endpoint(&server_address, None).await;
    let mut s1 = conn1.open_uni().await.unwrap();
    let s2 = conn2.open_uni().await;
    if let Ok(mut s2) = s2 {
        s1.write_all(&[0u8]).await.unwrap();
        s1.finish().unwrap();
        // NOTE(review): twice PACKET_DATA_SIZE, presumably so the write cannot complete
        // before the server reacts to the second connection — confirm.
        let data = vec![1u8; PACKET_DATA_SIZE * 2];
        s2.write_all(&data)
            .await
            .expect_err("shouldn't be able to open 2 connections");
    } else {
        assert_matches!(s2, Err(quinn::ConnectionError::ApplicationClosed(_)));
    }
}
1603
/// Sends 10 single-byte streams on each of two connections (200 ms apart), then waits up
/// to ~10 s for all 20 packets and checks that each received packet has size 1.
pub async fn check_multiple_streams(
    receiver: Receiver<PacketBatch>,
    server_address: SocketAddr,
) {
    let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);
    let conn2 = Arc::new(make_client_endpoint(&server_address, None).await);
    let mut num_expected_packets = 0;
    for i in 0..10 {
        info!("sending: {}", i);
        let c1 = conn1.clone();
        let c2 = conn2.clone();
        let mut s1 = c1.open_uni().await.unwrap();
        let mut s2 = c2.open_uni().await.unwrap();
        s1.write_all(&[0u8]).await.unwrap();
        s1.finish().unwrap();
        s2.write_all(&[0u8]).await.unwrap();
        s2.finish().unwrap();
        num_expected_packets += 2;
        sleep(Duration::from_millis(200)).await;
    }
    let mut all_packets = vec![];
    let now = Instant::now();
    let mut total_packets = 0;
    // Poll for up to 10 s; stop early once every expected packet has arrived.
    while now.elapsed().as_secs() < 10 {
        if let Ok(packets) = receiver.try_recv() {
            total_packets += packets.len();
            all_packets.push(packets)
        } else {
            sleep(Duration::from_secs(1)).await;
        }
        if total_packets == num_expected_packets {
            break;
        }
    }
    for batch in all_packets {
        for p in batch.iter() {
            assert_eq!(p.meta().size, 1);
        }
    }
    assert_eq!(total_packets, num_expected_packets);
}
1645
/// Writes `PACKET_DATA_SIZE` bytes one byte at a time on a single stream and checks that
/// the server reassembles them into exactly one packet of that size (waits up to ~5 s).
pub async fn check_multiple_writes(
    receiver: Receiver<PacketBatch>,
    server_address: SocketAddr,
    client_keypair: Option<&Keypair>,
) {
    let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);

    let num_bytes = PACKET_DATA_SIZE;
    // All the small writes belong to one stream and therefore one packet.
    let num_expected_packets = 1;
    let mut s1 = conn1.open_uni().await.unwrap();
    for _ in 0..num_bytes {
        s1.write_all(&[0u8]).await.unwrap();
    }
    s1.finish().unwrap();

    let mut all_packets = vec![];
    let now = Instant::now();
    let mut total_packets = 0;
    while now.elapsed().as_secs() < 5 {
        if let Ok(packets) = receiver.try_recv() {
            total_packets += packets.len();
            all_packets.push(packets)
        } else {
            sleep(Duration::from_secs(1)).await;
        }
        if total_packets >= num_expected_packets {
            break;
        }
    }
    for batch in all_packets {
        for p in batch.iter() {
            assert_eq!(p.meta().size, num_bytes);
        }
    }
    assert_eq!(total_packets, num_expected_packets);
}
1685
/// Attempts to send a full-size packet from an unstaked client. If a stream can be opened
/// at all, write errors are ignored and waiting for the stream to be stopped must fail.
pub async fn check_unstaked_node_connect_failure(server_address: SocketAddr) {
    let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);

    if let Ok(mut s1) = conn1.open_uni().await {
        // Errors are deliberately ignored: the server is expected to drop this peer.
        for _ in 0..PACKET_DATA_SIZE {
            s1.write_all(&[0u8]).await.unwrap_or_default();
        }
        s1.finish().unwrap_or_default();
        s1.stopped().await.unwrap_err();
    }
}
1699
/// The server task terminates once the `exit` flag is set.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_server_exit() {
    let SpawnTestServerResult {
        join_handle,
        exit,
        receiver: _,
        server_address: _,
        stats: _,
    } = setup_quic_server(None, TestServerConfig::default());
    exit.store(true, Ordering::Relaxed);
    join_handle.await.unwrap();
}
1712
/// Runs `check_timeout` against a default test server: streams sent one second apart on
/// a single connection must all be delivered.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_timeout() {
    solana_logger::setup();
    let SpawnTestServerResult {
        join_handle,
        exit,
        receiver,
        server_address,
        stats: _,
    } = setup_quic_server(None, TestServerConfig::default());

    check_timeout(receiver, server_address).await;
    exit.store(true, Ordering::Relaxed);
    join_handle.await.unwrap();
}
1728
/// Feeds 1000 single-chunk accumulators directly into `packet_batch_sender` and checks that
/// all of them come out in packet batches within 2 s.
#[tokio::test(flavor = "multi_thread")]
async fn test_packet_batcher() {
    solana_logger::setup();
    let (pkt_batch_sender, pkt_batch_receiver) = unbounded();
    let (ptk_sender, pkt_receiver) = async_unbounded();
    let exit = Arc::new(AtomicBool::new(false));
    let stats = Arc::new(StreamerStats::default());

    let handle = tokio::spawn(packet_batch_sender(
        pkt_batch_sender,
        pkt_receiver,
        exit.clone(),
        stats,
        DEFAULT_TPU_COALESCE,
    ));

    let num_packets = 1000;

    for _i in 0..num_packets {
        let mut meta = Meta::default();
        let bytes = Bytes::from("Hello world");
        let size = bytes.len();
        meta.size = size;
        let packet_accum = PacketAccumulator {
            meta,
            chunks: smallvec::smallvec![bytes],
            start_time: Instant::now(),
        };
        ptk_sender.send(packet_accum).await.unwrap();
    }
    let mut i = 0;
    let start = Instant::now();
    while i < num_packets && start.elapsed().as_secs() < 2 {
        if let Ok(batch) = pkt_batch_receiver.try_recv() {
            i += batch.len();
        } else {
            sleep(Duration::from_millis(1)).await;
        }
    }
    assert_eq!(i, num_packets);
    exit.store(true, Ordering::Relaxed);
    // NOTE(review): dropping the sender closes the channel, presumably so the batcher task
    // observes shutdown even while waiting on the receiver — confirm.
    drop(ptk_sender);
    handle.await.unwrap();
}
1774
/// A stream that sends one byte and then stalls is timed out by the server. The read
/// timeout counter increases, the stream is no longer counted, and further writes fail.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_stream_timeout() {
    solana_logger::setup();
    let SpawnTestServerResult {
        join_handle,
        exit,
        receiver: _,
        server_address,
        stats,
    } = setup_quic_server(None, TestServerConfig::default());

    let conn1 = make_client_endpoint(&server_address, None).await;
    assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
    assert_eq!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);

    let mut s1 = conn1.open_uni().await.unwrap();
    s1.write_all(&[0u8]).await.unwrap_or_default();

    // Wait well past the server's per-chunk read timeout without finishing the stream.
    let sleep_time = DEFAULT_WAIT_FOR_CHUNK_TIMEOUT * 2;
    sleep(sleep_time).await;

    assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
    assert_ne!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);

    assert!(s1.write_all(&[0u8]).await.is_err());

    exit.store(true, Ordering::Relaxed);
    join_handle.await.unwrap();
}
1809
/// Runs `check_block_multiple_connections` against a default test server.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_server_block_multiple_connections() {
    solana_logger::setup();
    let SpawnTestServerResult {
        join_handle,
        exit,
        receiver: _,
        server_address,
        stats: _,
    } = setup_quic_server(None, TestServerConfig::default());
    check_block_multiple_connections(server_address).await;
    exit.store(true, Ordering::Relaxed);
    join_handle.await.unwrap();
}
1824
1825 #[tokio::test(flavor = "multi_thread")]
1826 async fn test_quic_server_multiple_connections_on_single_client_endpoint() {
1827 solana_logger::setup();
1828
1829 let SpawnTestServerResult {
1830 join_handle,
1831 exit,
1832 receiver: _,
1833 server_address,
1834 stats,
1835 } = setup_quic_server(
1836 None,
1837 TestServerConfig {
1838 max_connections_per_peer: 2,
1839 ..Default::default()
1840 },
1841 );
1842
1843 let client_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
1844 let mut endpoint = quinn::Endpoint::new(
1845 EndpointConfig::default(),
1846 None,
1847 client_socket,
1848 Arc::new(TokioRuntime),
1849 )
1850 .unwrap();
1851 let default_keypair = Keypair::new();
1852 endpoint.set_default_client_config(get_client_config(&default_keypair));
1853 let conn1 = endpoint
1854 .connect(server_address, "localhost")
1855 .expect("Failed in connecting")
1856 .await
1857 .expect("Failed in waiting");
1858
1859 let conn2 = endpoint
1860 .connect(server_address, "localhost")
1861 .expect("Failed in connecting")
1862 .await
1863 .expect("Failed in waiting");
1864
1865 let mut s1 = conn1.open_uni().await.unwrap();
1866 s1.write_all(&[0u8]).await.unwrap();
1867 s1.finish().unwrap();
1868
1869 let mut s2 = conn2.open_uni().await.unwrap();
1870 conn1.close(
1871 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1872 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1873 );
1874
1875 let start = Instant::now();
1876 while stats.connection_removed.load(Ordering::Relaxed) != 1 {
1877 debug!("First connection not removed yet");
1878 sleep(Duration::from_millis(10)).await;
1879 }
1880 assert!(start.elapsed().as_secs() < 1);
1881
1882 s2.write_all(&[0u8]).await.unwrap();
1883 s2.finish().unwrap();
1884
1885 conn2.close(
1886 CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1887 CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1888 );
1889
1890 let start = Instant::now();
1891 while stats.connection_removed.load(Ordering::Relaxed) != 2 {
1892 debug!("Second connection not removed yet");
1893 sleep(Duration::from_millis(10)).await;
1894 }
1895 assert!(start.elapsed().as_secs() < 1);
1896
1897 exit.store(true, Ordering::Relaxed);
1898 join_handle.await.unwrap();
1899 }
1900
/// Runs `check_multiple_writes` with an unstaked client against a default test server.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_server_multiple_writes() {
    solana_logger::setup();
    let SpawnTestServerResult {
        join_handle,
        exit,
        receiver,
        server_address,
        stats: _,
    } = setup_quic_server(None, TestServerConfig::default());
    check_multiple_writes(receiver, server_address, None).await;
    exit.store(true, Ordering::Relaxed);
    join_handle.await.unwrap();
}
1915
/// A client with 100_000 stake is never counted as an unstaked connection, and its
/// connection is removed exactly once (with no failed removals) after shutdown.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_server_staked_connection_removal() {
    solana_logger::setup();

    let client_keypair = Keypair::new();
    let stakes = HashMap::from([(client_keypair.pubkey(), 100_000)]);
    let staked_nodes = StakedNodes::new(
        Arc::new(stakes),
        // NOTE(review): presumably per-pubkey stake overrides; empty here — confirm.
        HashMap::<Pubkey, u64>::default(),
    );
    let SpawnTestServerResult {
        join_handle,
        exit,
        receiver,
        server_address,
        stats,
    } = setup_quic_server(Some(staked_nodes), TestServerConfig::default());
    check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
    exit.store(true, Ordering::Relaxed);
    join_handle.await.unwrap();
    sleep(Duration::from_millis(100)).await;
    assert_eq!(
        stats
            .connection_added_from_unstaked_peer
            .load(Ordering::Relaxed),
        0
    );
    assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
    assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
}
1946
/// A client registered with zero stake is not counted as a staked connection, and its
/// connection is removed exactly once (with no failed removals) after shutdown.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_server_zero_staked_connection_removal() {
    solana_logger::setup();

    let client_keypair = Keypair::new();
    let stakes = HashMap::from([(client_keypair.pubkey(), 0)]);
    let staked_nodes = StakedNodes::new(
        Arc::new(stakes),
        // NOTE(review): presumably per-pubkey stake overrides; empty here — confirm.
        HashMap::<Pubkey, u64>::default(),
    );
    let SpawnTestServerResult {
        join_handle,
        exit,
        receiver,
        server_address,
        stats,
    } = setup_quic_server(Some(staked_nodes), TestServerConfig::default());
    check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
    exit.store(true, Ordering::Relaxed);
    join_handle.await.unwrap();
    sleep(Duration::from_millis(100)).await;
    assert_eq!(
        stats
            .connection_added_from_staked_peer
            .load(Ordering::Relaxed),
        0
    );
    assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
    assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
}
1978
/// An unstaked client is not counted as a staked connection, and its connection is
/// removed exactly once (with no failed removals) after shutdown.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_server_unstaked_connection_removal() {
    solana_logger::setup();
    let SpawnTestServerResult {
        join_handle,
        exit,
        receiver,
        server_address,
        stats,
    } = setup_quic_server(None, TestServerConfig::default());
    check_multiple_writes(receiver, server_address, None).await;
    exit.store(true, Ordering::Relaxed);
    join_handle.await.unwrap();
    sleep(Duration::from_millis(100)).await;
    assert_eq!(
        stats
            .connection_added_from_staked_peer
            .load(Ordering::Relaxed),
        0
    );
    assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
    assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
}
2002
/// With zero unstaked connection slots, an unstaked client cannot deliver a packet
/// (see `check_unstaked_node_connect_failure`).
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_server_unstaked_node_connect_failure() {
    solana_logger::setup();
    let s = UdpSocket::bind("127.0.0.1:0").unwrap();
    let exit = Arc::new(AtomicBool::new(false));
    let (sender, _) = unbounded();
    let keypair = Keypair::new();
    let server_address = s.local_addr().unwrap();
    let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
    let SpawnNonBlockingServerResult {
        endpoints: _,
        stats: _,
        thread: t,
        max_concurrent_connections: _,
    } = spawn_server(
        "quic_streamer_test",
        s,
        &keypair,
        sender,
        exit.clone(),
        1,
        staked_nodes,
        MAX_STAKED_CONNECTIONS,
        // Same argument position as `MAX_UNSTAKED_CONNECTIONS` in other tests: no
        // unstaked connections are allowed.
        0,
        DEFAULT_MAX_STREAMS_PER_MS,
        DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE,
        DEFAULT_WAIT_FOR_CHUNK_TIMEOUT,
        DEFAULT_TPU_COALESCE,
    )
    .unwrap();

    check_unstaked_node_connect_failure(server_address).await;
    exit.store(true, Ordering::Relaxed);
    t.await.unwrap();
}
2038
/// Runs `check_multiple_streams` and verifies the stream/connection statistics.
/// All 20 streams are counted and then retired, and both connections are released
/// after the server exits.
#[tokio::test(flavor = "multi_thread")]
async fn test_quic_server_multiple_streams() {
    solana_logger::setup();
    let s = UdpSocket::bind("127.0.0.1:0").unwrap();
    let exit = Arc::new(AtomicBool::new(false));
    let (sender, receiver) = unbounded();
    let keypair = Keypair::new();
    let server_address = s.local_addr().unwrap();
    let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
    let SpawnNonBlockingServerResult {
        endpoints: _,
        stats,
        thread: t,
        max_concurrent_connections: _,
    } = spawn_server(
        "quic_streamer_test",
        s,
        &keypair,
        sender,
        exit.clone(),
        2,
        staked_nodes,
        MAX_STAKED_CONNECTIONS,
        MAX_UNSTAKED_CONNECTIONS,
        DEFAULT_MAX_STREAMS_PER_MS,
        DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE,
        DEFAULT_WAIT_FOR_CHUNK_TIMEOUT,
        DEFAULT_TPU_COALESCE,
    )
    .unwrap();

    check_multiple_streams(receiver, server_address).await;
    assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
    assert_eq!(stats.total_new_streams.load(Ordering::Relaxed), 20);
    assert_eq!(stats.total_connections.load(Ordering::Relaxed), 2);
    assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
    exit.store(true, Ordering::Relaxed);
    t.await.unwrap();
    assert_eq!(stats.total_connections.load(Ordering::Relaxed), 0);
    assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
}
2080
/// `prune_oldest` on IP-keyed entries evicts the peers with the oldest `last_update`
/// until only `new_size` connections remain, keeping `total_size` consistent.
#[test]
fn test_prune_table_with_ip() {
    use std::net::Ipv4Addr;
    solana_logger::setup();
    let mut table = ConnectionTable::new();
    let mut num_entries = 5;
    let max_connections_per_peer = 10;
    let sockets: Vec<_> = (0..num_entries)
        .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
        .collect();
    let stats = Arc::new(StreamerStats::default());
    for (i, socket) in sockets.iter().enumerate() {
        table
            .try_add_connection(
                ConnectionTableKey::IP(socket.ip()),
                socket.port(),
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                i as u64,
                max_connections_per_peer,
            )
            .unwrap();
    }
    // A second connection for sockets[0]; that peer's oldest timestamp is still 0.
    num_entries += 1;
    table
        .try_add_connection(
            ConnectionTableKey::IP(sockets[0].ip()),
            sockets[0].port(),
            ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
            None,
            ConnectionPeerType::Unstaked,
            5,
            max_connections_per_peer,
        )
        .unwrap();

    let new_size = 3;
    let pruned = table.prune_oldest(new_size);
    assert_eq!(pruned, num_entries as usize - new_size);
    for v in table.table.values() {
        for x in v {
            assert!((x.last_update() + 1) >= (num_entries as u64 - new_size as u64));
        }
    }
    assert_eq!(table.table.len(), new_size);
    assert_eq!(table.total_size, new_size);
    // The survivors are sockets[2..]; removing them empties the table.
    for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
        table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
    }
    assert_eq!(table.total_size, 0);
    assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
}
2134
/// `prune_oldest` on pubkey-keyed entries (one connection each) keeps the `new_size`
/// most recently updated peers.
#[test]
fn test_prune_table_with_unique_pubkeys() {
    solana_logger::setup();
    let mut table = ConnectionTable::new();

    let num_entries = 15;
    let max_connections_per_peer = 10;
    let stats = Arc::new(StreamerStats::default());

    let pubkeys: Vec<_> = (0..num_entries).map(|_| Pubkey::new_unique()).collect();
    for (i, pubkey) in pubkeys.iter().enumerate() {
        table
            .try_add_connection(
                ConnectionTableKey::Pubkey(*pubkey),
                0,
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                i as u64,
                max_connections_per_peer,
            )
            .unwrap();
    }

    let new_size = 3;
    let pruned = table.prune_oldest(new_size);
    assert_eq!(pruned, num_entries as usize - new_size);
    assert_eq!(table.table.len(), new_size);
    assert_eq!(table.total_size, new_size);
    for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) {
        table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0);
    }
    assert_eq!(table.total_size, 0);
    assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
}
2172
/// A single pubkey is capped at `max_connections_per_peer` connections, a different
/// pubkey can still connect, and `prune_oldest` evicts whole peers at a time.
#[test]
fn test_prune_table_with_non_unique_pubkeys() {
    solana_logger::setup();
    let mut table = ConnectionTable::new();

    let max_connections_per_peer = 10;
    let pubkey = Pubkey::new_unique();
    let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());

    (0..max_connections_per_peer).for_each(|i| {
        table
            .try_add_connection(
                ConnectionTableKey::Pubkey(pubkey),
                0,
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                i as u64,
                max_connections_per_peer,
            )
            .unwrap();
    });

    // One more connection for the same pubkey exceeds the per-peer cap.
    assert!(table
        .try_add_connection(
            ConnectionTableKey::Pubkey(pubkey),
            0,
            ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
            None,
            ConnectionPeerType::Unstaked,
            10,
            max_connections_per_peer,
        )
        .is_none());

    // A different pubkey is unaffected by the first peer's cap.
    let num_entries = max_connections_per_peer + 1;
    let pubkey2 = Pubkey::new_unique();
    assert!(table
        .try_add_connection(
            ConnectionTableKey::Pubkey(pubkey2),
            0,
            ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
            None,
            ConnectionPeerType::Unstaked,
            10,
            max_connections_per_peer,
        )
        .is_some());

    assert_eq!(table.total_size, num_entries);

    let new_max_size = 3;
    let pruned = table.prune_oldest(new_max_size);
    assert!(pruned >= num_entries - new_max_size);
    assert!(table.table.len() <= new_max_size);
    assert!(table.total_size <= new_max_size);

    table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0);
    assert_eq!(table.total_size, 0);
    assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
}
2237
/// `prune_random` evicts nothing when every stake is at or above the threshold, and
/// evicts exactly one peer when all stakes are below it.
#[test]
fn test_prune_table_random() {
    use std::net::Ipv4Addr;
    solana_logger::setup();
    let mut table = ConnectionTable::new();
    let num_entries = 5;
    let max_connections_per_peer = 10;
    let sockets: Vec<_> = (0..num_entries)
        .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
        .collect();
    let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());

    // Stakes are 1..=num_entries.
    for (i, socket) in sockets.iter().enumerate() {
        table
            .try_add_connection(
                ConnectionTableKey::IP(socket.ip()),
                socket.port(),
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Staked((i + 1) as u64),
                i as u64,
                max_connections_per_peer,
            )
            .unwrap();
    }

    // Threshold 0: no stake is strictly below it.
    let pruned = table.prune_random(2, 0);
    assert_eq!(pruned, 0);

    // Threshold above every stake: whichever sampled peer has the lowest stake is evicted.
    let pruned = table.prune_random(
        2,
        num_entries as u64 + 1,
    );
    assert_eq!(pruned, 1);
    assert_eq!(stats.open_connections.load(Ordering::Relaxed), 4);
}
2279
/// `remove_connection` handles keys with two connections, one connection, and none,
/// leaving `total_size` and the open-connection stat at zero.
#[test]
fn test_remove_connections() {
    use std::net::Ipv4Addr;
    solana_logger::setup();
    let mut table = ConnectionTable::new();
    let num_ips = 5;
    let max_connections_per_peer = 10;
    let mut sockets: Vec<_> = (0..num_ips)
        .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
        .collect();
    let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());

    // Two connections per IP.
    for (i, socket) in sockets.iter().enumerate() {
        table
            .try_add_connection(
                ConnectionTableKey::IP(socket.ip()),
                socket.port(),
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                (i * 2) as u64,
                max_connections_per_peer,
            )
            .unwrap();

        table
            .try_add_connection(
                ConnectionTableKey::IP(socket.ip()),
                socket.port(),
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Unstaked,
                (i * 2 + 1) as u64,
                max_connections_per_peer,
            )
            .unwrap();
    }

    let single_connection_addr =
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips, 0, 0, 0)), 0);
    table
        .try_add_connection(
            ConnectionTableKey::IP(single_connection_addr.ip()),
            single_connection_addr.port(),
            ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
            None,
            ConnectionPeerType::Unstaked,
            (num_ips * 2) as u64,
            max_connections_per_peer,
        )
        .unwrap();

    // Never added: removing it must be a harmless no-op.
    let zero_connection_addr =
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips + 1, 0, 0, 0)), 0);

    sockets.push(single_connection_addr);
    sockets.push(zero_connection_addr);

    for socket in sockets.iter() {
        table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
    }
    assert_eq!(table.total_size, 0);
    assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
}
2344
/// Checks `compute_max_allowed_uni_streams` for:
/// - unstaked peers;
/// - staked peers when the total stake is zero;
/// - staked peers holding 10% and 1% of a 10000 total stake.
#[test]
fn test_max_allowed_uni_streams() {
    let delta =
        (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS - QUIC_MIN_STAKED_CONCURRENT_STREAMS) as f64;
    // Expected limit for a peer holding 1/100 of the total stake: the staked minimum plus
    // 1% of the staked headroom, capped at the staked maximum.
    let one_percent_limit = ((delta / (100_f64)) as usize + QUIC_MIN_STAKED_CONCURRENT_STREAMS)
        .min(QUIC_MAX_STAKED_CONCURRENT_STREAMS);

    // (peer type, total stake, expected max concurrent uni streams), checked in order.
    let cases = [
        (ConnectionPeerType::Unstaked, 0, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS),
        (ConnectionPeerType::Staked(10), 0, QUIC_MIN_STAKED_CONCURRENT_STREAMS),
        (ConnectionPeerType::Staked(1000), 10000, QUIC_MAX_STAKED_CONCURRENT_STREAMS),
        (ConnectionPeerType::Staked(100), 10000, one_percent_limit),
        (ConnectionPeerType::Unstaked, 10000, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS),
    ];
    for (peer_type, total_stake, expected) in cases {
        assert_eq!(compute_max_allowed_uni_streams(peer_type, total_stake), expected);
    }
}
2372
/// Checks `compute_receive_window_ratio_for_staked_node` in these cases:
/// - a stake at the bottom, top and midpoint of a `[min_stake, max_stake]` range;
/// - degenerate ranges where `max_stake == min_stake`, including both zero;
/// - a stake above `max_stake`.
#[test]
fn test_cacluate_receive_window_ratio_for_staked_node() {
    let min_ratio = QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO;
    let max_ratio = QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
    let average_ratio = (max_ratio + min_ratio) / 2;

    // (max_stake, min_stake, stake, expected ratio), checked in order.
    let cases = [
        (10000, 0, 0, min_ratio),
        (10000, 0, 10000, max_ratio),
        (10000, 0, 5000, average_ratio),
        (10000, 10000, 10000, max_ratio),
        (0, 0, 0, max_ratio),
        (1000, 10, 1010, max_ratio),
    ];
    for (max_stake, min_stake, stake, expected) in cases {
        assert_eq!(
            compute_receive_window_ratio_for_staked_node(max_stake, min_stake, stake),
            expected
        );
    }
}
2406
2407 #[tokio::test(flavor = "multi_thread")]
2408 async fn test_throttling_check_no_packet_drop() {
2409 solana_logger::setup_with_default_filter();
2410
2411 let SpawnTestServerResult {
2412 join_handle,
2413 exit,
2414 receiver,
2415 server_address,
2416 stats,
2417 } = setup_quic_server(None, TestServerConfig::default());
2418
2419 let client_connection = make_client_endpoint(&server_address, None).await;
2420
2421 let expected_num_txs = 100;
2423 let start_time = tokio::time::Instant::now();
2424 for i in 0..expected_num_txs {
2425 let mut send_stream = client_connection.open_uni().await.unwrap();
2426 let data = format!("{i}").into_bytes();
2427 send_stream.write_all(&data).await.unwrap();
2428 send_stream.finish().unwrap();
2429 }
2430 let elapsed_sending: f64 = start_time.elapsed().as_secs_f64();
2431 info!("Elapsed sending: {elapsed_sending}");
2432
2433 let start_time = tokio::time::Instant::now();
2435 let mut num_txs_received = 0;
2436 while num_txs_received < expected_num_txs && start_time.elapsed() < Duration::from_secs(2) {
2437 if let Ok(packets) = receiver.try_recv() {
2438 num_txs_received += packets.len();
2439 } else {
2440 sleep(Duration::from_millis(100)).await;
2441 }
2442 }
2443 assert_eq!(expected_num_txs, num_txs_received);
2444
2445 exit.store(true, Ordering::Relaxed);
2447 join_handle.await.unwrap();
2448
2449 assert_eq!(
2450 stats.total_new_streams.load(Ordering::Relaxed),
2451 expected_num_txs
2452 );
2453 assert!(stats.throttled_unstaked_streams.load(Ordering::Relaxed) > 0);
2454 }
2455
/// Checks that `ClientConnectionTracker::new` refuses a second tracker once the limit
/// of 1 is reached, and that dropping the first tracker brings `open_connections`
/// back to 0.
#[test]
fn test_client_connection_tracker() {
    let stats = Arc::new(StreamerStats::default());
    let open_connections = || stats.open_connections.load(Ordering::Relaxed);

    let first_tracker = ClientConnectionTracker::new(Arc::clone(&stats), 1);
    assert!(first_tracker.is_ok());
    // A second tracker against the same limit of 1 must be refused.
    assert!(ClientConnectionTracker::new(Arc::clone(&stats), 1).is_err());
    assert_eq!(open_connections(), 1);

    // Dropping the admitted tracker should release its slot.
    drop(first_tracker);
    assert_eq!(open_connections(), 0);
}
2467
2468 #[tokio::test(flavor = "multi_thread")]
2469 async fn test_client_connection_close_invalid_stream() {
2470 let SpawnTestServerResult {
2471 join_handle,
2472 server_address,
2473 stats,
2474 exit,
2475 ..
2476 } = setup_quic_server(None, TestServerConfig::default());
2477
2478 let client_connection = make_client_endpoint(&server_address, None).await;
2479
2480 let mut send_stream = client_connection.open_uni().await.unwrap();
2481 send_stream
2482 .write_all(&[42; PACKET_DATA_SIZE + 1])
2483 .await
2484 .unwrap();
2485 match client_connection.closed().await {
2486 ConnectionError::ApplicationClosed(ApplicationClose { error_code, reason }) => {
2487 assert_eq!(error_code, CONNECTION_CLOSE_CODE_INVALID_STREAM.into());
2488 assert_eq!(reason, CONNECTION_CLOSE_REASON_INVALID_STREAM);
2489 }
2490 _ => panic!("unexpected close"),
2491 }
2492 assert_eq!(stats.invalid_stream_size.load(Ordering::Relaxed), 1);
2493 exit.store(true, Ordering::Relaxed);
2494 join_handle.await.unwrap();
2495 }
2496}