solana_streamer/nonblocking/
quic.rs

1use {
2    crate::{
3        nonblocking::{
4            connection_rate_limiter::{ConnectionRateLimiter, TotalConnectionRateLimiter},
5            stream_throttle::{
6                ConnectionStreamCounter, StakedStreamLoadEMA, STREAM_THROTTLING_INTERVAL,
7                STREAM_THROTTLING_INTERVAL_MS,
8            },
9        },
10        quic::{configure_server, QuicServerError, StreamerStats},
11        streamer::StakedNodes,
12        tls_certificates::get_pubkey_from_tls_certificate,
13    },
14    async_channel::{
15        unbounded as async_unbounded, Receiver as AsyncReceiver, Sender as AsyncSender,
16    },
17    bytes::Bytes,
18    crossbeam_channel::Sender,
19    futures::{stream::FuturesUnordered, Future, StreamExt as _},
20    indexmap::map::{Entry, IndexMap},
21    percentage::Percentage,
22    quinn::{Accept, Connecting, Connection, Endpoint, EndpointConfig, TokioRuntime, VarInt},
23    quinn_proto::VarIntBoundsExceeded,
24    rand::{thread_rng, Rng},
25    smallvec::SmallVec,
26    solana_measure::measure::Measure,
27    solana_perf::packet::{PacketBatch, PACKETS_PER_BATCH},
28    solana_sdk::{
29        packet::{Meta, PACKET_DATA_SIZE},
30        pubkey::Pubkey,
31        quic::{
32            QUIC_CONNECTION_HANDSHAKE_TIMEOUT, QUIC_MAX_STAKED_CONCURRENT_STREAMS,
33            QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
34            QUIC_MIN_STAKED_CONCURRENT_STREAMS, QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO,
35            QUIC_TOTAL_STAKED_CONCURRENT_STREAMS, QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO,
36        },
37        signature::{Keypair, Signature},
38        timing,
39    },
40    solana_transaction_metrics_tracker::signature_if_should_track_packet,
41    std::{
42        array,
43        fmt,
44        iter::repeat_with,
45        net::{IpAddr, SocketAddr, UdpSocket},
46        pin::Pin,
47        // CAUTION: be careful not to introduce any awaits while holding an RwLock.
48        sync::{
49            atomic::{AtomicBool, AtomicU64, Ordering},
50            Arc, RwLock,
51        },
52        task::Poll,
53        time::{Duration, Instant},
54    },
55    tokio::{
56        // CAUTION: It's kind of sketch that we're mixing async and sync locks (see the RwLock above).
57        // This is done so that sync code can also access the stake table.
58        // Make sure we don't hold a sync lock across an await - including the await to
59        // lock an async Mutex. This does not happen now and should not happen as long as we
60        // don't hold an async Mutex and sync RwLock at the same time (currently true)
61        // but if we do, the scope of the RwLock must always be a subset of the async Mutex
62        // (i.e. lock order is always async Mutex -> RwLock). Also, be careful not to
63        // introduce any other awaits while holding the RwLock.
64        select,
65        sync::{Mutex, MutexGuard},
66        task::JoinHandle,
67        time::{sleep, timeout},
68    },
69    tokio_util::sync::CancellationToken,
70};
71
72pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);
73
74pub const ALPN_TPU_PROTOCOL_ID: &[u8] = b"solana-tpu";
75
76const CONNECTION_CLOSE_CODE_DROPPED_ENTRY: u32 = 1;
77const CONNECTION_CLOSE_REASON_DROPPED_ENTRY: &[u8] = b"dropped";
78
79const CONNECTION_CLOSE_CODE_DISALLOWED: u32 = 2;
80const CONNECTION_CLOSE_REASON_DISALLOWED: &[u8] = b"disallowed";
81
82const CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT: u32 = 3;
83const CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT: &[u8] = b"exceed_max_stream_count";
84
85const CONNECTION_CLOSE_CODE_TOO_MANY: u32 = 4;
86const CONNECTION_CLOSE_REASON_TOO_MANY: &[u8] = b"too_many";
87
88const CONNECTION_CLOSE_CODE_INVALID_STREAM: u32 = 5;
89const CONNECTION_CLOSE_REASON_INVALID_STREAM: &[u8] = b"invalid_stream";
90
91/// Limit to 250K PPS
92pub const DEFAULT_MAX_STREAMS_PER_MS: u64 = 250;
93
94/// The new connections per minute from a particular IP address.
95/// Heuristically set to the default maximum concurrent connections
96/// per IP address. Might be adjusted later.
97pub const DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE: u64 = 8;
98
99/// Total new connection counts per second. Heuristically taken from
100/// the default staked and unstaked connection limits. Might be adjusted
101/// later.
102const TOTAL_CONNECTIONS_PER_SECOND: u64 = 2500;
103
104/// The threshold of the size of the connection rate limiter map. When
105/// the map size is above this, we will trigger a cleanup of older
106/// entries used by past requests.
107const CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD: usize = 100_000;
108
109// A struct to accumulate the bytes making up
110// a packet, along with their offsets, and the
111// packet metadata. We use this accumulator to avoid
112// multiple copies of the Bytes (when building up
113// the Packet and then when copying the Packet into a PacketBatch)
114#[derive(Clone)]
115struct PacketAccumulator {
116    pub meta: Meta,
117    pub chunks: SmallVec<[Bytes; 2]>,
118    pub start_time: Instant,
119}
120
121impl PacketAccumulator {
122    fn new(meta: Meta) -> Self {
123        Self {
124            meta,
125            chunks: SmallVec::default(),
126            start_time: Instant::now(),
127        }
128    }
129}
130
131#[derive(Copy, Clone, Debug)]
132pub enum ConnectionPeerType {
133    Unstaked,
134    Staked(u64),
135}
136
137impl ConnectionPeerType {
138    pub(crate) fn is_staked(&self) -> bool {
139        matches!(self, ConnectionPeerType::Staked(_))
140    }
141}
142
143pub struct SpawnNonBlockingServerResult {
144    pub endpoints: Vec<Endpoint>,
145    pub stats: Arc<StreamerStats>,
146    pub thread: JoinHandle<()>,
147    pub max_concurrent_connections: usize,
148}
149
150#[allow(clippy::too_many_arguments)]
151pub fn spawn_server(
152    name: &'static str,
153    sock: UdpSocket,
154    keypair: &Keypair,
155    packet_sender: Sender<PacketBatch>,
156    exit: Arc<AtomicBool>,
157    max_connections_per_peer: usize,
158    staked_nodes: Arc<RwLock<StakedNodes>>,
159    max_staked_connections: usize,
160    max_unstaked_connections: usize,
161    max_streams_per_ms: u64,
162    max_connections_per_ipaddr_per_min: u64,
163    wait_for_chunk_timeout: Duration,
164    coalesce: Duration,
165) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
166    spawn_server_multi(
167        name,
168        vec![sock],
169        keypair,
170        packet_sender,
171        exit,
172        max_connections_per_peer,
173        staked_nodes,
174        max_staked_connections,
175        max_unstaked_connections,
176        max_streams_per_ms,
177        max_connections_per_ipaddr_per_min,
178        wait_for_chunk_timeout,
179        coalesce,
180    )
181}
182
183#[allow(clippy::too_many_arguments, clippy::type_complexity)]
184pub fn spawn_server_multi(
185    name: &'static str,
186    sockets: Vec<UdpSocket>,
187    keypair: &Keypair,
188    packet_sender: Sender<PacketBatch>,
189    exit: Arc<AtomicBool>,
190    max_connections_per_peer: usize,
191    staked_nodes: Arc<RwLock<StakedNodes>>,
192    max_staked_connections: usize,
193    max_unstaked_connections: usize,
194    max_streams_per_ms: u64,
195    max_connections_per_ipaddr_per_min: u64,
196    wait_for_chunk_timeout: Duration,
197    coalesce: Duration,
198) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
199    info!("Start {name} quic server on {sockets:?}");
200    let concurrent_connections = max_staked_connections + max_unstaked_connections;
201    let max_concurrent_connections = concurrent_connections + concurrent_connections / 4;
202    let (config, _) = configure_server(keypair)?;
203
204    let endpoints = sockets
205        .into_iter()
206        .map(|sock| {
207            Endpoint::new(
208                EndpointConfig::default(),
209                Some(config.clone()),
210                sock,
211                Arc::new(TokioRuntime),
212            )
213            .map_err(QuicServerError::EndpointFailed)
214        })
215        .collect::<Result<Vec<_>, _>>()?;
216    let stats = Arc::<StreamerStats>::default();
217    let handle = tokio::spawn(run_server(
218        name,
219        endpoints.clone(),
220        packet_sender,
221        exit,
222        max_connections_per_peer,
223        staked_nodes,
224        max_staked_connections,
225        max_unstaked_connections,
226        max_streams_per_ms,
227        max_connections_per_ipaddr_per_min,
228        stats.clone(),
229        wait_for_chunk_timeout,
230        coalesce,
231        max_concurrent_connections,
232    ));
233    Ok(SpawnNonBlockingServerResult {
234        endpoints,
235        stats,
236        thread: handle,
237        max_concurrent_connections,
238    })
239}
240
241/// struct ease tracking connections of all stages, so that we do not have to
242/// litter the code with open connection tracking. This is added into the
243/// connection table as part of the ConnectionEntry. The reference is auto
244/// reduced when it is dropped.
245
246struct ClientConnectionTracker {
247    stats: Arc<StreamerStats>,
248}
249
250/// This is required by ConnectionEntry for supporting debug format.
251impl fmt::Debug for ClientConnectionTracker {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        f.debug_struct("StreamerClientConnection")
254            .field(
255                "open_connections:",
256                &self.stats.open_connections.load(Ordering::Relaxed),
257            )
258            .finish()
259    }
260}
261
262impl Drop for ClientConnectionTracker {
263    /// When this is dropped, reduce the open connection count.
264    fn drop(&mut self) {
265        self.stats.open_connections.fetch_sub(1, Ordering::Relaxed);
266    }
267}
268
269impl ClientConnectionTracker {
270    /// Check the max_concurrent_connections limit and if it is within the limit
271    /// create ClientConnectionTracker and increment open connection count. Otherwise returns Err
272    fn new(stats: Arc<StreamerStats>, max_concurrent_connections: usize) -> Result<Self, ()> {
273        let open_connections = stats.open_connections.fetch_add(1, Ordering::Relaxed);
274        if open_connections >= max_concurrent_connections {
275            stats.open_connections.fetch_sub(1, Ordering::Relaxed);
276            debug!(
277                "There are too many concurrent connections opened already: open: {}, max: {}",
278                open_connections, max_concurrent_connections
279            );
280            return Err(());
281        }
282
283        Ok(Self { stats })
284    }
285}
286
287#[allow(clippy::too_many_arguments)]
288async fn run_server(
289    name: &'static str,
290    endpoints: Vec<Endpoint>,
291    packet_sender: Sender<PacketBatch>,
292    exit: Arc<AtomicBool>,
293    max_connections_per_peer: usize,
294    staked_nodes: Arc<RwLock<StakedNodes>>,
295    max_staked_connections: usize,
296    max_unstaked_connections: usize,
297    max_streams_per_ms: u64,
298    max_connections_per_ipaddr_per_min: u64,
299    stats: Arc<StreamerStats>,
300    wait_for_chunk_timeout: Duration,
301    coalesce: Duration,
302    max_concurrent_connections: usize,
303) {
304    let rate_limiter = ConnectionRateLimiter::new(max_connections_per_ipaddr_per_min);
305    let overall_connection_rate_limiter =
306        TotalConnectionRateLimiter::new(TOTAL_CONNECTIONS_PER_SECOND);
307
308    const WAIT_FOR_CONNECTION_TIMEOUT: Duration = Duration::from_secs(1);
309    debug!("spawn quic server");
310    let mut last_datapoint = Instant::now();
311    let unstaked_connection_table: Arc<Mutex<ConnectionTable>> =
312        Arc::new(Mutex::new(ConnectionTable::new()));
313    let stream_load_ema = Arc::new(StakedStreamLoadEMA::new(
314        stats.clone(),
315        max_unstaked_connections,
316        max_streams_per_ms,
317    ));
318    stats
319        .quic_endpoints_count
320        .store(endpoints.len(), Ordering::Relaxed);
321    let staked_connection_table: Arc<Mutex<ConnectionTable>> =
322        Arc::new(Mutex::new(ConnectionTable::new()));
323    let (sender, receiver) = async_unbounded();
324    tokio::spawn(packet_batch_sender(
325        packet_sender,
326        receiver,
327        exit.clone(),
328        stats.clone(),
329        coalesce,
330    ));
331
332    let mut accepts = endpoints
333        .iter()
334        .enumerate()
335        .map(|(i, incoming)| {
336            Box::pin(EndpointAccept {
337                accept: incoming.accept(),
338                endpoint: i,
339            })
340        })
341        .collect::<FuturesUnordered<_>>();
342
343    while !exit.load(Ordering::Relaxed) {
344        let timeout_connection = select! {
345            ready = accepts.next() => {
346                if let Some((connecting, i)) = ready {
347                    accepts.push(
348                        Box::pin(EndpointAccept {
349                            accept: endpoints[i].accept(),
350                            endpoint: i,
351                        }
352                    ));
353                    Ok(connecting)
354                } else {
355                    // we can't really get here - we never poll an empty FuturesUnordered
356                    continue
357                }
358            }
359            _ = tokio::time::sleep(WAIT_FOR_CONNECTION_TIMEOUT) => {
360                Err(())
361            }
362        };
363
364        if last_datapoint.elapsed().as_secs() >= 5 {
365            stats.report(name);
366            last_datapoint = Instant::now();
367        }
368
369        if let Ok(Some(incoming)) = timeout_connection {
370            stats
371                .total_incoming_connection_attempts
372                .fetch_add(1, Ordering::Relaxed);
373
374            let remote_address = incoming.remote_address();
375
376            // first check overall connection rate limit:
377            if !overall_connection_rate_limiter.is_allowed() {
378                debug!(
379                    "Reject connection from {:?} -- total rate limiting exceeded",
380                    remote_address.ip()
381                );
382                stats
383                    .connection_rate_limited_across_all
384                    .fetch_add(1, Ordering::Relaxed);
385                incoming.ignore();
386                continue;
387            }
388
389            if rate_limiter.len() > CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD {
390                rate_limiter.retain_recent();
391            }
392            stats
393                .connection_rate_limiter_length
394                .store(rate_limiter.len(), Ordering::Relaxed);
395            debug!("Got a connection {remote_address:?}");
396            if !rate_limiter.is_allowed(&remote_address.ip()) {
397                debug!(
398                    "Reject connection from {:?} -- rate limiting exceeded",
399                    remote_address
400                );
401                stats
402                    .connection_rate_limited_per_ipaddr
403                    .fetch_add(1, Ordering::Relaxed);
404                incoming.ignore();
405                continue;
406            }
407
408            let Ok(client_connection_tracker) =
409                ClientConnectionTracker::new(stats.clone(), max_concurrent_connections)
410            else {
411                stats
412                    .refused_connections_too_many_open_connections
413                    .fetch_add(1, Ordering::Relaxed);
414                incoming.refuse();
415                continue;
416            };
417
418            stats
419                .outstanding_incoming_connection_attempts
420                .fetch_add(1, Ordering::Relaxed);
421            let connecting = incoming.accept();
422            match connecting {
423                Ok(connecting) => {
424                    tokio::spawn(setup_connection(
425                        connecting,
426                        client_connection_tracker,
427                        unstaked_connection_table.clone(),
428                        staked_connection_table.clone(),
429                        sender.clone(),
430                        max_connections_per_peer,
431                        staked_nodes.clone(),
432                        max_staked_connections,
433                        max_unstaked_connections,
434                        max_streams_per_ms,
435                        stats.clone(),
436                        wait_for_chunk_timeout,
437                        stream_load_ema.clone(),
438                    ));
439                }
440                Err(err) => {
441                    debug!("Incoming::accept(): error {:?}", err);
442                }
443            }
444        } else {
445            debug!("accept(): Timed out waiting for connection");
446        }
447    }
448}
449
450fn prune_unstaked_connection_table(
451    unstaked_connection_table: &mut ConnectionTable,
452    max_unstaked_connections: usize,
453    stats: Arc<StreamerStats>,
454) {
455    if unstaked_connection_table.total_size >= max_unstaked_connections {
456        const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
457        let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
458
459        let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
460        let num_pruned = unstaked_connection_table.prune_oldest(max_connections);
461        stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
462    }
463}
464
465pub fn get_remote_pubkey(connection: &Connection) -> Option<Pubkey> {
466    // Use the client cert only if it is self signed and the chain length is 1.
467    connection
468        .peer_identity()?
469        .downcast::<Vec<rustls::pki_types::CertificateDer>>()
470        .ok()
471        .filter(|certs| certs.len() == 1)?
472        .first()
473        .and_then(get_pubkey_from_tls_certificate)
474}
475
476fn get_connection_stake(
477    connection: &Connection,
478    staked_nodes: &RwLock<StakedNodes>,
479) -> Option<(Pubkey, u64, u64, u64, u64)> {
480    let pubkey = get_remote_pubkey(connection)?;
481    debug!("Peer public key is {pubkey:?}");
482    let staked_nodes = staked_nodes.read().unwrap();
483    Some((
484        pubkey,
485        staked_nodes.get_node_stake(&pubkey)?,
486        staked_nodes.total_stake(),
487        staked_nodes.max_stake(),
488        staked_nodes.min_stake(),
489    ))
490}
491
492pub fn compute_max_allowed_uni_streams(peer_type: ConnectionPeerType, total_stake: u64) -> usize {
493    match peer_type {
494        ConnectionPeerType::Staked(peer_stake) => {
495            // No checked math for f64 type. So let's explicitly check for 0 here
496            if total_stake == 0 || peer_stake > total_stake {
497                warn!(
498                    "Invalid stake values: peer_stake: {:?}, total_stake: {:?}",
499                    peer_stake, total_stake,
500                );
501
502                QUIC_MIN_STAKED_CONCURRENT_STREAMS
503            } else {
504                let delta = (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS
505                    - QUIC_MIN_STAKED_CONCURRENT_STREAMS) as f64;
506
507                (((peer_stake as f64 / total_stake as f64) * delta) as usize
508                    + QUIC_MIN_STAKED_CONCURRENT_STREAMS)
509                    .clamp(
510                        QUIC_MIN_STAKED_CONCURRENT_STREAMS,
511                        QUIC_MAX_STAKED_CONCURRENT_STREAMS,
512                    )
513            }
514        }
515        ConnectionPeerType::Unstaked => QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
516    }
517}
518
519enum ConnectionHandlerError {
520    ConnectionAddError,
521    MaxStreamError,
522}
523
524#[derive(Clone)]
525struct NewConnectionHandlerParams {
526    // In principle, the code can be made to work with a crossbeam channel
527    // as long as we're careful never to use a blocking recv or send call
528    // but I've found that it's simply too easy to accidentally block
529    // in async code when using the crossbeam channel, so for the sake of maintainability,
530    // we're sticking with an async channel
531    packet_sender: AsyncSender<PacketAccumulator>,
532    remote_pubkey: Option<Pubkey>,
533    peer_type: ConnectionPeerType,
534    total_stake: u64,
535    max_connections_per_peer: usize,
536    stats: Arc<StreamerStats>,
537    max_stake: u64,
538    min_stake: u64,
539}
540
541impl NewConnectionHandlerParams {
542    fn new_unstaked(
543        packet_sender: AsyncSender<PacketAccumulator>,
544        max_connections_per_peer: usize,
545        stats: Arc<StreamerStats>,
546    ) -> NewConnectionHandlerParams {
547        NewConnectionHandlerParams {
548            packet_sender,
549            remote_pubkey: None,
550            peer_type: ConnectionPeerType::Unstaked,
551            total_stake: 0,
552            max_connections_per_peer,
553            stats,
554            max_stake: 0,
555            min_stake: 0,
556        }
557    }
558}
559
560fn handle_and_cache_new_connection(
561    client_connection_tracker: ClientConnectionTracker,
562    connection: Connection,
563    mut connection_table_l: MutexGuard<ConnectionTable>,
564    connection_table: Arc<Mutex<ConnectionTable>>,
565    params: &NewConnectionHandlerParams,
566    wait_for_chunk_timeout: Duration,
567    stream_load_ema: Arc<StakedStreamLoadEMA>,
568) -> Result<(), ConnectionHandlerError> {
569    if let Ok(max_uni_streams) = VarInt::from_u64(compute_max_allowed_uni_streams(
570        params.peer_type,
571        params.total_stake,
572    ) as u64)
573    {
574        let remote_addr = connection.remote_address();
575        let receive_window =
576            compute_recieve_window(params.max_stake, params.min_stake, params.peer_type);
577
578        debug!(
579            "Peer type {:?}, total stake {}, max streams {} receive_window {:?} from peer {}",
580            params.peer_type,
581            params.total_stake,
582            max_uni_streams.into_inner(),
583            receive_window,
584            remote_addr,
585        );
586
587        if let Some((last_update, cancel_connection, stream_counter)) = connection_table_l
588            .try_add_connection(
589                ConnectionTableKey::new(remote_addr.ip(), params.remote_pubkey),
590                remote_addr.port(),
591                client_connection_tracker,
592                Some(connection.clone()),
593                params.peer_type,
594                timing::timestamp(),
595                params.max_connections_per_peer,
596            )
597        {
598            drop(connection_table_l);
599
600            if let Ok(receive_window) = receive_window {
601                connection.set_receive_window(receive_window);
602            }
603            connection.set_max_concurrent_uni_streams(max_uni_streams);
604
605            tokio::spawn(handle_connection(
606                connection,
607                remote_addr,
608                last_update,
609                connection_table,
610                cancel_connection,
611                params.clone(),
612                wait_for_chunk_timeout,
613                stream_load_ema,
614                stream_counter,
615            ));
616            Ok(())
617        } else {
618            params
619                .stats
620                .connection_add_failed
621                .fetch_add(1, Ordering::Relaxed);
622            Err(ConnectionHandlerError::ConnectionAddError)
623        }
624    } else {
625        connection.close(
626            CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT.into(),
627            CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT,
628        );
629        params
630            .stats
631            .connection_add_failed_invalid_stream_count
632            .fetch_add(1, Ordering::Relaxed);
633        Err(ConnectionHandlerError::MaxStreamError)
634    }
635}
636
637async fn prune_unstaked_connections_and_add_new_connection(
638    client_connection_tracker: ClientConnectionTracker,
639    connection: Connection,
640    connection_table: Arc<Mutex<ConnectionTable>>,
641    max_connections: usize,
642    params: &NewConnectionHandlerParams,
643    wait_for_chunk_timeout: Duration,
644    stream_load_ema: Arc<StakedStreamLoadEMA>,
645) -> Result<(), ConnectionHandlerError> {
646    let stats = params.stats.clone();
647    if max_connections > 0 {
648        let connection_table_clone = connection_table.clone();
649        let mut connection_table = connection_table.lock().await;
650        prune_unstaked_connection_table(&mut connection_table, max_connections, stats);
651        handle_and_cache_new_connection(
652            client_connection_tracker,
653            connection,
654            connection_table,
655            connection_table_clone,
656            params,
657            wait_for_chunk_timeout,
658            stream_load_ema,
659        )
660    } else {
661        connection.close(
662            CONNECTION_CLOSE_CODE_DISALLOWED.into(),
663            CONNECTION_CLOSE_REASON_DISALLOWED,
664        );
665        Err(ConnectionHandlerError::ConnectionAddError)
666    }
667}
668
669/// Calculate the ratio for per connection receive window from a staked peer
670fn compute_receive_window_ratio_for_staked_node(max_stake: u64, min_stake: u64, stake: u64) -> u64 {
671    // Testing shows the maximum througput from a connection is achieved at receive_window =
672    // PACKET_DATA_SIZE * 10. Beyond that, there is not much gain. We linearly map the
673    // stake to the ratio range from QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO to
674    // QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO. Where the linear algebra of finding the ratio 'r'
675    // for stake 's' is,
676    // r(s) = a * s + b. Given the max_stake, min_stake, max_ratio, min_ratio, we can find
677    // a and b.
678
679    if stake > max_stake {
680        return QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
681    }
682
683    let max_ratio = QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
684    let min_ratio = QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO;
685    if max_stake > min_stake {
686        let a = (max_ratio - min_ratio) as f64 / (max_stake - min_stake) as f64;
687        let b = max_ratio as f64 - ((max_stake as f64) * a);
688        let ratio = (a * stake as f64) + b;
689        ratio.round() as u64
690    } else {
691        QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO
692    }
693}
694
695fn compute_recieve_window(
696    max_stake: u64,
697    min_stake: u64,
698    peer_type: ConnectionPeerType,
699) -> Result<VarInt, VarIntBoundsExceeded> {
700    match peer_type {
701        ConnectionPeerType::Unstaked => {
702            VarInt::from_u64(PACKET_DATA_SIZE as u64 * QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO)
703        }
704        ConnectionPeerType::Staked(peer_stake) => {
705            let ratio =
706                compute_receive_window_ratio_for_staked_node(max_stake, min_stake, peer_stake);
707            VarInt::from_u64(PACKET_DATA_SIZE as u64 * ratio)
708        }
709    }
710}
711
712#[allow(clippy::too_many_arguments)]
713async fn setup_connection(
714    connecting: Connecting,
715    client_connection_tracker: ClientConnectionTracker,
716    unstaked_connection_table: Arc<Mutex<ConnectionTable>>,
717    staked_connection_table: Arc<Mutex<ConnectionTable>>,
718    packet_sender: AsyncSender<PacketAccumulator>,
719    max_connections_per_peer: usize,
720    staked_nodes: Arc<RwLock<StakedNodes>>,
721    max_staked_connections: usize,
722    max_unstaked_connections: usize,
723    max_streams_per_ms: u64,
724    stats: Arc<StreamerStats>,
725    wait_for_chunk_timeout: Duration,
726    stream_load_ema: Arc<StakedStreamLoadEMA>,
727) {
728    const PRUNE_RANDOM_SAMPLE_SIZE: usize = 2;
729    let from = connecting.remote_address();
730    let res = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await;
731    stats
732        .outstanding_incoming_connection_attempts
733        .fetch_sub(1, Ordering::Relaxed);
734    if let Ok(connecting_result) = res {
735        match connecting_result {
736            Ok(new_connection) => {
737                stats.total_new_connections.fetch_add(1, Ordering::Relaxed);
738
739                let params = get_connection_stake(&new_connection, &staked_nodes).map_or(
740                    NewConnectionHandlerParams::new_unstaked(
741                        packet_sender.clone(),
742                        max_connections_per_peer,
743                        stats.clone(),
744                    ),
745                    |(pubkey, stake, total_stake, max_stake, min_stake)| {
746                        // The heuristic is that the stake should be large engouh to have 1 stream pass throuh within one throttle
747                        // interval during which we allow max (MAX_STREAMS_PER_MS * STREAM_THROTTLING_INTERVAL_MS) streams.
748                        let min_stake_ratio =
749                            1_f64 / (max_streams_per_ms * STREAM_THROTTLING_INTERVAL_MS) as f64;
750                        let stake_ratio = stake as f64 / total_stake as f64;
751                        let peer_type = if stake_ratio < min_stake_ratio {
752                            // If it is a staked connection with ultra low stake ratio, treat it as unstaked.
753                            ConnectionPeerType::Unstaked
754                        } else {
755                            ConnectionPeerType::Staked(stake)
756                        };
757                        NewConnectionHandlerParams {
758                            packet_sender,
759                            remote_pubkey: Some(pubkey),
760                            peer_type,
761                            total_stake,
762                            max_connections_per_peer,
763                            stats: stats.clone(),
764                            max_stake,
765                            min_stake,
766                        }
767                    },
768                );
769
770                match params.peer_type {
771                    ConnectionPeerType::Staked(stake) => {
772                        let mut connection_table_l = staked_connection_table.lock().await;
773
774                        if connection_table_l.total_size >= max_staked_connections {
775                            let num_pruned =
776                                connection_table_l.prune_random(PRUNE_RANDOM_SAMPLE_SIZE, stake);
777                            stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
778                        }
779
780                        if connection_table_l.total_size < max_staked_connections {
781                            if let Ok(()) = handle_and_cache_new_connection(
782                                client_connection_tracker,
783                                new_connection,
784                                connection_table_l,
785                                staked_connection_table.clone(),
786                                &params,
787                                wait_for_chunk_timeout,
788                                stream_load_ema.clone(),
789                            ) {
790                                stats
791                                    .connection_added_from_staked_peer
792                                    .fetch_add(1, Ordering::Relaxed);
793                            }
794                        } else {
795                            // If we couldn't prune a connection in the staked connection table, let's
796                            // put this connection in the unstaked connection table. If needed, prune a
797                            // connection from the unstaked connection table.
798                            if let Ok(()) = prune_unstaked_connections_and_add_new_connection(
799                                client_connection_tracker,
800                                new_connection,
801                                unstaked_connection_table.clone(),
802                                max_unstaked_connections,
803                                &params,
804                                wait_for_chunk_timeout,
805                                stream_load_ema.clone(),
806                            )
807                            .await
808                            {
809                                stats
810                                    .connection_added_from_staked_peer
811                                    .fetch_add(1, Ordering::Relaxed);
812                            } else {
813                                stats
814                                    .connection_add_failed_on_pruning
815                                    .fetch_add(1, Ordering::Relaxed);
816                                stats
817                                    .connection_add_failed_staked_node
818                                    .fetch_add(1, Ordering::Relaxed);
819                            }
820                        }
821                    }
822                    ConnectionPeerType::Unstaked => {
823                        if let Ok(()) = prune_unstaked_connections_and_add_new_connection(
824                            client_connection_tracker,
825                            new_connection,
826                            unstaked_connection_table.clone(),
827                            max_unstaked_connections,
828                            &params,
829                            wait_for_chunk_timeout,
830                            stream_load_ema.clone(),
831                        )
832                        .await
833                        {
834                            stats
835                                .connection_added_from_unstaked_peer
836                                .fetch_add(1, Ordering::Relaxed);
837                        } else {
838                            stats
839                                .connection_add_failed_unstaked_node
840                                .fetch_add(1, Ordering::Relaxed);
841                        }
842                    }
843                }
844            }
845            Err(e) => {
846                handle_connection_error(e, &stats, from);
847            }
848        }
849    } else {
850        stats
851            .connection_setup_timeout
852            .fetch_add(1, Ordering::Relaxed);
853    }
854}
855
856fn handle_connection_error(e: quinn::ConnectionError, stats: &StreamerStats, from: SocketAddr) {
857    debug!("error: {:?} from: {:?}", e, from);
858    stats.connection_setup_error.fetch_add(1, Ordering::Relaxed);
859    match e {
860        quinn::ConnectionError::TimedOut => {
861            stats
862                .connection_setup_error_timed_out
863                .fetch_add(1, Ordering::Relaxed);
864        }
865        quinn::ConnectionError::ConnectionClosed(_) => {
866            stats
867                .connection_setup_error_closed
868                .fetch_add(1, Ordering::Relaxed);
869        }
870        quinn::ConnectionError::TransportError(_) => {
871            stats
872                .connection_setup_error_transport
873                .fetch_add(1, Ordering::Relaxed);
874        }
875        quinn::ConnectionError::ApplicationClosed(_) => {
876            stats
877                .connection_setup_error_app_closed
878                .fetch_add(1, Ordering::Relaxed);
879        }
880        quinn::ConnectionError::Reset => {
881            stats
882                .connection_setup_error_reset
883                .fetch_add(1, Ordering::Relaxed);
884        }
885        quinn::ConnectionError::LocallyClosed => {
886            stats
887                .connection_setup_error_locally_closed
888                .fetch_add(1, Ordering::Relaxed);
889        }
890        _ => {}
891    }
892}
893
894// Holder(s) of the AsyncSender<PacketAccumulator> on the other end should not
895// wait for this function to exit to exit
896async fn packet_batch_sender(
897    packet_sender: Sender<PacketBatch>,
898    packet_receiver: AsyncReceiver<PacketAccumulator>,
899    exit: Arc<AtomicBool>,
900    stats: Arc<StreamerStats>,
901    coalesce: Duration,
902) {
903    trace!("enter packet_batch_sender");
904    let mut batch_start_time = Instant::now();
905    loop {
906        let mut packet_perf_measure: Vec<([u8; 64], Instant)> = Vec::default();
907        let mut packet_batch = PacketBatch::with_capacity(PACKETS_PER_BATCH);
908        let mut total_bytes: usize = 0;
909
910        stats
911            .total_packet_batches_allocated
912            .fetch_add(1, Ordering::Relaxed);
913        stats
914            .total_packets_allocated
915            .fetch_add(PACKETS_PER_BATCH, Ordering::Relaxed);
916
917        loop {
918            if exit.load(Ordering::Relaxed) {
919                return;
920            }
921            let elapsed = batch_start_time.elapsed();
922            if packet_batch.len() >= PACKETS_PER_BATCH
923                || (!packet_batch.is_empty() && elapsed >= coalesce)
924            {
925                let len = packet_batch.len();
926                track_streamer_fetch_packet_performance(&packet_perf_measure, &stats);
927
928                if let Err(e) = packet_sender.send(packet_batch) {
929                    stats
930                        .total_packet_batch_send_err
931                        .fetch_add(1, Ordering::Relaxed);
932                    trace!("Send error: {}", e);
933                } else {
934                    stats
935                        .total_packet_batches_sent
936                        .fetch_add(1, Ordering::Relaxed);
937
938                    stats
939                        .total_packets_sent_to_consumer
940                        .fetch_add(len, Ordering::Relaxed);
941
942                    stats
943                        .total_bytes_sent_to_consumer
944                        .fetch_add(total_bytes, Ordering::Relaxed);
945
946                    trace!("Sent {} packet batch", len);
947                }
948                break;
949            }
950
951            let timeout_res = if !packet_batch.is_empty() {
952                // If we get here, elapsed < coalesce (see above if condition)
953                timeout(coalesce - elapsed, packet_receiver.recv()).await
954            } else {
955                // Small bit of non-idealness here: the holder(s) of the other end
956                // of packet_receiver must drop it (without waiting for us to exit)
957                // or we have a chance of sleeping here forever
958                // and never polling exit. Not a huge deal in practice as the
959                // only time this happens is when we tear down the server
960                // and at that time the other end does indeed not wait for us
961                // to exit here
962                Ok(packet_receiver.recv().await)
963            };
964
965            if let Ok(Ok(packet_accumulator)) = timeout_res {
966                // Start the timeout from when the packet batch first becomes non-empty
967                if packet_batch.is_empty() {
968                    batch_start_time = Instant::now();
969                }
970
971                unsafe {
972                    packet_batch.set_len(packet_batch.len() + 1);
973                }
974
975                let i = packet_batch.len() - 1;
976                *packet_batch[i].meta_mut() = packet_accumulator.meta;
977                let num_chunks = packet_accumulator.chunks.len();
978                let mut offset = 0;
979                for chunk in packet_accumulator.chunks {
980                    packet_batch[i].buffer_mut()[offset..offset + chunk.len()]
981                        .copy_from_slice(&chunk);
982                    offset += chunk.len();
983                }
984
985                total_bytes += packet_batch[i].meta().size;
986
987                if let Some(signature) = signature_if_should_track_packet(&packet_batch[i])
988                    .ok()
989                    .flatten()
990                {
991                    packet_perf_measure.push((*signature, packet_accumulator.start_time));
992                    // we set the PERF_TRACK_PACKET on
993                    packet_batch[i].meta_mut().set_track_performance(true);
994                }
995                stats
996                    .total_chunks_processed_by_batcher
997                    .fetch_add(num_chunks, Ordering::Relaxed);
998            }
999        }
1000    }
1001}
1002
1003fn track_streamer_fetch_packet_performance(
1004    packet_perf_measure: &[([u8; 64], Instant)],
1005    stats: &StreamerStats,
1006) {
1007    if packet_perf_measure.is_empty() {
1008        return;
1009    }
1010    let mut measure = Measure::start("track_perf");
1011    let mut process_sampled_packets_us_hist = stats.process_sampled_packets_us_hist.lock().unwrap();
1012
1013    let now = Instant::now();
1014    for (signature, start_time) in packet_perf_measure {
1015        let duration = now.duration_since(*start_time);
1016        debug!(
1017            "QUIC streamer fetch stage took {duration:?} for transaction {:?}",
1018            Signature::from(*signature)
1019        );
1020        process_sampled_packets_us_hist
1021            .increment(duration.as_micros() as u64)
1022            .unwrap();
1023    }
1024
1025    drop(process_sampled_packets_us_hist);
1026    measure.stop();
1027    stats
1028        .perf_track_overhead_us
1029        .fetch_add(measure.as_us(), Ordering::Relaxed);
1030}
1031
1032async fn handle_connection(
1033    connection: Connection,
1034    remote_addr: SocketAddr,
1035    last_update: Arc<AtomicU64>,
1036    connection_table: Arc<Mutex<ConnectionTable>>,
1037    cancel: CancellationToken,
1038    params: NewConnectionHandlerParams,
1039    wait_for_chunk_timeout: Duration,
1040    stream_load_ema: Arc<StakedStreamLoadEMA>,
1041    stream_counter: Arc<ConnectionStreamCounter>,
1042) {
1043    let NewConnectionHandlerParams {
1044        packet_sender,
1045        peer_type,
1046        remote_pubkey,
1047        stats,
1048        total_stake,
1049        ..
1050    } = params;
1051
1052    debug!(
1053        "quic new connection {} streams: {} connections: {}",
1054        remote_addr,
1055        stats.total_streams.load(Ordering::Relaxed),
1056        stats.total_connections.load(Ordering::Relaxed),
1057    );
1058    stats.total_connections.fetch_add(1, Ordering::Relaxed);
1059
1060    'conn: loop {
1061        // Wait for new streams. If the peer is disconnected we get a cancellation signal and stop
1062        // the connection task.
1063        let mut stream = select! {
1064            stream = connection.accept_uni() => match stream {
1065                Ok(stream) => stream,
1066                Err(e) => {
1067                    debug!("stream error: {:?}", e);
1068                    break;
1069                }
1070            },
1071            _ = cancel.cancelled() => break,
1072        };
1073
1074        let max_streams_per_throttling_interval =
1075            stream_load_ema.available_load_capacity_in_throttling_duration(peer_type, total_stake);
1076
1077        let throttle_interval_start = stream_counter.reset_throttling_params_if_needed();
1078        let streams_read_in_throttle_interval = stream_counter.stream_count.load(Ordering::Relaxed);
1079        if streams_read_in_throttle_interval >= max_streams_per_throttling_interval {
1080            // The peer is sending faster than we're willing to read. Sleep for what's
1081            // left of this read interval so the peer backs off.
1082            let throttle_duration =
1083                STREAM_THROTTLING_INTERVAL.saturating_sub(throttle_interval_start.elapsed());
1084
1085            if !throttle_duration.is_zero() {
1086                debug!("Throttling stream from {remote_addr:?}, peer type: {:?}, total stake: {}, \
1087                                    max_streams_per_interval: {max_streams_per_throttling_interval}, read_interval_streams: {streams_read_in_throttle_interval} \
1088                                    throttle_duration: {throttle_duration:?}",
1089                                    peer_type, total_stake);
1090                stats.throttled_streams.fetch_add(1, Ordering::Relaxed);
1091                match peer_type {
1092                    ConnectionPeerType::Unstaked => {
1093                        stats
1094                            .throttled_unstaked_streams
1095                            .fetch_add(1, Ordering::Relaxed);
1096                    }
1097                    ConnectionPeerType::Staked(_) => {
1098                        stats
1099                            .throttled_staked_streams
1100                            .fetch_add(1, Ordering::Relaxed);
1101                    }
1102                }
1103                sleep(throttle_duration).await;
1104            }
1105        }
1106        stream_load_ema.increment_load(peer_type);
1107        stream_counter.stream_count.fetch_add(1, Ordering::Relaxed);
1108        stats.total_streams.fetch_add(1, Ordering::Relaxed);
1109        stats.total_new_streams.fetch_add(1, Ordering::Relaxed);
1110
1111        let mut meta = Meta::default();
1112        meta.set_socket_addr(&remote_addr);
1113        meta.set_from_staked_node(matches!(peer_type, ConnectionPeerType::Staked(_)));
1114        let mut accum = PacketAccumulator::new(meta);
1115
1116        // Virtually all small transactions will fit in 1 chunk. Larger transactions will fit in 1
1117        // or 2 chunks if the first chunk starts towards the end of a datagram. A small number of
1118        // transaction will have other protocol frames inserted in the middle. Empirically it's been
1119        // observed that 4 is the maximum number of chunks txs get split into.
1120        //
1121        // Bytes values are small, so overall the array takes only 128 bytes, and the "cost" of
1122        // overallocating a few bytes is negligible compared to the cost of having to do multiple
1123        // read_chunks() calls.
1124        let mut chunks: [Bytes; 4] = array::from_fn(|_| Bytes::new());
1125
1126        loop {
1127            // Read the next chunks, waiting up to `wait_for_chunk_timeout`. If we don't get chunks
1128            // before then, we assume the stream is dead. This can only happen if there's severe
1129            // packet loss or the peer stops sending for whatever reason.
1130            let n_chunks = match tokio::select! {
1131                chunk = tokio::time::timeout(
1132                    wait_for_chunk_timeout,
1133                    stream.read_chunks(&mut chunks)) => chunk,
1134
1135                // If the peer gets disconnected stop the task right away.
1136                _ = cancel.cancelled() => break,
1137            } {
1138                // read_chunk returned success
1139                Ok(Ok(chunk)) => chunk.unwrap_or(0),
1140                // read_chunk returned error
1141                Ok(Err(e)) => {
1142                    debug!("Received stream error: {:?}", e);
1143                    stats
1144                        .total_stream_read_errors
1145                        .fetch_add(1, Ordering::Relaxed);
1146                    break;
1147                }
1148                // timeout elapsed
1149                Err(_) => {
1150                    debug!("Timeout in receiving on stream");
1151                    stats
1152                        .total_stream_read_timeouts
1153                        .fetch_add(1, Ordering::Relaxed);
1154                    break;
1155                }
1156            };
1157
1158            match handle_chunks(
1159                // Bytes::clone() is a cheap atomic inc
1160                chunks.iter().take(n_chunks).cloned(),
1161                &mut accum,
1162                &packet_sender,
1163                &stats,
1164                peer_type,
1165            )
1166            .await
1167            {
1168                // The stream is finished, break out of the loop and close the stream.
1169                Ok(StreamState::Finished) => {
1170                    last_update.store(timing::timestamp(), Ordering::Relaxed);
1171                    break;
1172                }
1173                // The stream is still active, continue reading.
1174                Ok(StreamState::Receiving) => {}
1175                Err(_) => {
1176                    // Disconnect peers that send invalid streams.
1177                    connection.close(
1178                        CONNECTION_CLOSE_CODE_INVALID_STREAM.into(),
1179                        CONNECTION_CLOSE_REASON_INVALID_STREAM,
1180                    );
1181                    stats.total_streams.fetch_sub(1, Ordering::Relaxed);
1182                    stream_load_ema.update_ema_if_needed();
1183                    break 'conn;
1184                }
1185            }
1186        }
1187
1188        stats.total_streams.fetch_sub(1, Ordering::Relaxed);
1189        stream_load_ema.update_ema_if_needed();
1190    }
1191
1192    let stable_id = connection.stable_id();
1193    let removed_connection_count = connection_table.lock().await.remove_connection(
1194        ConnectionTableKey::new(remote_addr.ip(), remote_pubkey),
1195        remote_addr.port(),
1196        stable_id,
1197    );
1198    if removed_connection_count > 0 {
1199        stats
1200            .connection_removed
1201            .fetch_add(removed_connection_count, Ordering::Relaxed);
1202    } else {
1203        stats
1204            .connection_remove_failed
1205            .fetch_add(1, Ordering::Relaxed);
1206    }
1207    stats.total_connections.fetch_sub(1, Ordering::Relaxed);
1208}
1209
1210enum StreamState {
1211    // Stream is not finished, keep receiving chunks
1212    Receiving,
1213    // Stream is finished
1214    Finished,
1215}
1216
1217// Handle the chunks received from the stream. If the stream is finished, send the packet to the
1218// packet sender.
1219//
1220// Returns Err(()) if the stream is invalid.
1221async fn handle_chunks(
1222    chunks: impl ExactSizeIterator<Item = Bytes>,
1223    accum: &mut PacketAccumulator,
1224    packet_sender: &AsyncSender<PacketAccumulator>,
1225    stats: &StreamerStats,
1226    peer_type: ConnectionPeerType,
1227) -> Result<StreamState, ()> {
1228    let n_chunks = chunks.len();
1229    for chunk in chunks {
1230        accum.meta.size += chunk.len();
1231        if accum.meta.size > PACKET_DATA_SIZE {
1232            // The stream window size is set to PACKET_DATA_SIZE, so one individual chunk can
1233            // never exceed this size. A peer can send two chunks that together exceed the size
1234            // tho, in which case we report the error.
1235            stats.invalid_stream_size.fetch_add(1, Ordering::Relaxed);
1236            debug!("invalid stream size {}", accum.meta.size);
1237            return Err(());
1238        }
1239        accum.chunks.push(chunk);
1240        if peer_type.is_staked() {
1241            stats
1242                .total_staked_chunks_received
1243                .fetch_add(1, Ordering::Relaxed);
1244        } else {
1245            stats
1246                .total_unstaked_chunks_received
1247                .fetch_add(1, Ordering::Relaxed);
1248        }
1249    }
1250
1251    // n_chunks == 0 marks the end of a stream
1252    if n_chunks != 0 {
1253        return Ok(StreamState::Receiving);
1254    }
1255
1256    if accum.chunks.is_empty() {
1257        debug!("stream is empty");
1258        stats
1259            .total_packet_batches_none
1260            .fetch_add(1, Ordering::Relaxed);
1261        return Err(());
1262    }
1263
1264    // done receiving chunks
1265    let bytes_sent = accum.meta.size;
1266    let chunks_sent = accum.chunks.len();
1267
1268    if let Err(err) = packet_sender.send(accum.clone()).await {
1269        stats
1270            .total_handle_chunk_to_packet_batcher_send_err
1271            .fetch_add(1, Ordering::Relaxed);
1272        trace!("packet batch send error {:?}", err);
1273    } else {
1274        stats
1275            .total_packets_sent_for_batching
1276            .fetch_add(1, Ordering::Relaxed);
1277        stats
1278            .total_bytes_sent_for_batching
1279            .fetch_add(bytes_sent, Ordering::Relaxed);
1280        stats
1281            .total_chunks_sent_for_batching
1282            .fetch_add(chunks_sent, Ordering::Relaxed);
1283
1284        match peer_type {
1285            ConnectionPeerType::Unstaked => {
1286                stats
1287                    .total_unstaked_packets_sent_for_batching
1288                    .fetch_add(1, Ordering::Relaxed);
1289            }
1290            ConnectionPeerType::Staked(_) => {
1291                stats
1292                    .total_staked_packets_sent_for_batching
1293                    .fetch_add(1, Ordering::Relaxed);
1294            }
1295        }
1296
1297        trace!("sent {} byte packet for batching", bytes_sent);
1298    }
1299
1300    Ok(StreamState::Finished)
1301}
1302
1303#[derive(Debug)]
1304struct ConnectionEntry {
1305    cancel: CancellationToken,
1306    peer_type: ConnectionPeerType,
1307    last_update: Arc<AtomicU64>,
1308    port: u16,
1309    // We do not explicitly use it, but its drop is triggered when ConnectionEntry is dropped.
1310    _client_connection_tracker: ClientConnectionTracker,
1311    connection: Option<Connection>,
1312    stream_counter: Arc<ConnectionStreamCounter>,
1313}
1314
1315impl ConnectionEntry {
1316    fn new(
1317        cancel: CancellationToken,
1318        peer_type: ConnectionPeerType,
1319        last_update: Arc<AtomicU64>,
1320        port: u16,
1321        client_connection_tracker: ClientConnectionTracker,
1322        connection: Option<Connection>,
1323        stream_counter: Arc<ConnectionStreamCounter>,
1324    ) -> Self {
1325        Self {
1326            cancel,
1327            peer_type,
1328            last_update,
1329            port,
1330            _client_connection_tracker: client_connection_tracker,
1331            connection,
1332            stream_counter,
1333        }
1334    }
1335
1336    fn last_update(&self) -> u64 {
1337        self.last_update.load(Ordering::Relaxed)
1338    }
1339
1340    fn stake(&self) -> u64 {
1341        match self.peer_type {
1342            ConnectionPeerType::Unstaked => 0,
1343            ConnectionPeerType::Staked(stake) => stake,
1344        }
1345    }
1346}
1347
1348impl Drop for ConnectionEntry {
1349    fn drop(&mut self) {
1350        if let Some(conn) = self.connection.take() {
1351            conn.close(
1352                CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1353                CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1354            );
1355        }
1356        self.cancel.cancel();
1357    }
1358}
1359
1360#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
1361enum ConnectionTableKey {
1362    IP(IpAddr),
1363    Pubkey(Pubkey),
1364}
1365
1366impl ConnectionTableKey {
1367    fn new(ip: IpAddr, maybe_pubkey: Option<Pubkey>) -> Self {
1368        maybe_pubkey.map_or(ConnectionTableKey::IP(ip), |pubkey| {
1369            ConnectionTableKey::Pubkey(pubkey)
1370        })
1371    }
1372}
1373
1374// Map of IP to list of connection entries
1375struct ConnectionTable {
1376    table: IndexMap<ConnectionTableKey, Vec<ConnectionEntry>>,
1377    total_size: usize,
1378}
1379
1380// Prune the connection which has the oldest update
1381// Return number pruned
1382impl ConnectionTable {
1383    fn new() -> Self {
1384        Self {
1385            table: IndexMap::default(),
1386            total_size: 0,
1387        }
1388    }
1389
1390    fn prune_oldest(&mut self, max_size: usize) -> usize {
1391        let mut num_pruned = 0;
1392        let key = |(_, connections): &(_, &Vec<_>)| {
1393            connections.iter().map(ConnectionEntry::last_update).min()
1394        };
1395        while self.total_size.saturating_sub(num_pruned) > max_size {
1396            match self.table.values().enumerate().min_by_key(key) {
1397                None => break,
1398                Some((index, connections)) => {
1399                    num_pruned += connections.len();
1400                    self.table.swap_remove_index(index);
1401                }
1402            }
1403        }
1404        self.total_size = self.total_size.saturating_sub(num_pruned);
1405        num_pruned
1406    }
1407
1408    // Randomly selects sample_size many connections, evicts the one with the
1409    // lowest stake, and returns the number of pruned connections.
1410    // If the stakes of all the sampled connections are higher than the
1411    // threshold_stake, rejects the pruning attempt, and returns 0.
1412    fn prune_random(&mut self, sample_size: usize, threshold_stake: u64) -> usize {
1413        let num_pruned = std::iter::once(self.table.len())
1414            .filter(|&size| size > 0)
1415            .flat_map(|size| {
1416                let mut rng = thread_rng();
1417                repeat_with(move || rng.gen_range(0..size))
1418            })
1419            .map(|index| {
1420                let connection = self.table[index].first();
1421                let stake = connection.map(|connection: &ConnectionEntry| connection.stake());
1422                (index, stake)
1423            })
1424            .take(sample_size)
1425            .min_by_key(|&(_, stake)| stake)
1426            .filter(|&(_, stake)| stake < Some(threshold_stake))
1427            .and_then(|(index, _)| self.table.swap_remove_index(index))
1428            .map(|(_, connections)| connections.len())
1429            .unwrap_or_default();
1430        self.total_size = self.total_size.saturating_sub(num_pruned);
1431        num_pruned
1432    }
1433
1434    fn try_add_connection(
1435        &mut self,
1436        key: ConnectionTableKey,
1437        port: u16,
1438        client_connection_tracker: ClientConnectionTracker,
1439        connection: Option<Connection>,
1440        peer_type: ConnectionPeerType,
1441        last_update: u64,
1442        max_connections_per_peer: usize,
1443    ) -> Option<(
1444        Arc<AtomicU64>,
1445        CancellationToken,
1446        Arc<ConnectionStreamCounter>,
1447    )> {
1448        let connection_entry = self.table.entry(key).or_default();
1449        let has_connection_capacity = connection_entry
1450            .len()
1451            .checked_add(1)
1452            .map(|c| c <= max_connections_per_peer)
1453            .unwrap_or(false);
1454        if has_connection_capacity {
1455            let cancel = CancellationToken::new();
1456            let last_update = Arc::new(AtomicU64::new(last_update));
1457            let stream_counter = connection_entry
1458                .first()
1459                .map(|entry| entry.stream_counter.clone())
1460                .unwrap_or(Arc::new(ConnectionStreamCounter::new()));
1461            connection_entry.push(ConnectionEntry::new(
1462                cancel.clone(),
1463                peer_type,
1464                last_update.clone(),
1465                port,
1466                client_connection_tracker,
1467                connection,
1468                stream_counter.clone(),
1469            ));
1470            self.total_size += 1;
1471            Some((last_update, cancel, stream_counter))
1472        } else {
1473            if let Some(connection) = connection {
1474                connection.close(
1475                    CONNECTION_CLOSE_CODE_TOO_MANY.into(),
1476                    CONNECTION_CLOSE_REASON_TOO_MANY,
1477                );
1478            }
1479            None
1480        }
1481    }
1482
1483    // Returns number of connections that were removed
1484    fn remove_connection(&mut self, key: ConnectionTableKey, port: u16, stable_id: usize) -> usize {
1485        if let Entry::Occupied(mut e) = self.table.entry(key) {
1486            let e_ref = e.get_mut();
1487            let old_size = e_ref.len();
1488
1489            e_ref.retain(|connection_entry| {
1490                // Retain the connection entry if the port is different, or if the connection's
1491                // stable_id doesn't match the provided stable_id.
1492                // (Some unit tests do not fill in a valid connection in the table. To support that,
1493                // if the connection is none, the stable_id check is ignored. i.e. if the port matches,
1494                // the connection gets removed)
1495                connection_entry.port != port
1496                    || connection_entry
1497                        .connection
1498                        .as_ref()
1499                        .and_then(|connection| (connection.stable_id() != stable_id).then_some(0))
1500                        .is_some()
1501            });
1502            let new_size = e_ref.len();
1503            if e_ref.is_empty() {
1504                e.swap_remove_entry();
1505            }
1506            let connections_removed = old_size.saturating_sub(new_size);
1507            self.total_size = self.total_size.saturating_sub(connections_removed);
1508            connections_removed
1509        } else {
1510            0
1511        }
1512    }
1513}
1514
1515struct EndpointAccept<'a> {
1516    endpoint: usize,
1517    accept: Accept<'a>,
1518}
1519
1520impl<'a> Future for EndpointAccept<'a> {
1521    type Output = (Option<quinn::Incoming>, usize);
1522
1523    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
1524        let i = self.endpoint;
1525        // Safety:
1526        // self is pinned and accept is a field so it can't get moved out. See safety docs of
1527        // map_unchecked_mut.
1528        unsafe { self.map_unchecked_mut(|this| &mut this.accept) }
1529            .poll(cx)
1530            .map(|r| (r, i))
1531    }
1532}
1533
1534#[cfg(test)]
1535pub mod test {
1536    use {
1537        super::*,
1538        crate::{
1539            nonblocking::{
1540                quic::compute_max_allowed_uni_streams,
1541                testing_utilities::{
1542                    get_client_config, make_client_endpoint, setup_quic_server,
1543                    SpawnTestServerResult, TestServerConfig,
1544                },
1545            },
1546            quic::{MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS},
1547        },
1548        assert_matches::assert_matches,
1549        async_channel::unbounded as async_unbounded,
1550        crossbeam_channel::{unbounded, Receiver},
1551        quinn::{ApplicationClose, ConnectionError},
1552        solana_sdk::{net::DEFAULT_TPU_COALESCE, signature::Keypair, signer::Signer},
1553        std::collections::HashMap,
1554        tokio::time::sleep,
1555    };
1556
1557    pub async fn check_timeout(receiver: Receiver<PacketBatch>, server_address: SocketAddr) {
1558        let conn1 = make_client_endpoint(&server_address, None).await;
1559        let total = 30;
1560        for i in 0..total {
1561            let mut s1 = conn1.open_uni().await.unwrap();
1562            s1.write_all(&[0u8]).await.unwrap();
1563            s1.finish().unwrap();
1564            info!("done {}", i);
1565            sleep(Duration::from_millis(1000)).await;
1566        }
1567        let mut received = 0;
1568        loop {
1569            if let Ok(_x) = receiver.try_recv() {
1570                received += 1;
1571                info!("got {}", received);
1572            } else {
1573                sleep(Duration::from_millis(500)).await;
1574            }
1575            if received >= total {
1576                break;
1577            }
1578        }
1579    }
1580
1581    pub async fn check_block_multiple_connections(server_address: SocketAddr) {
1582        let conn1 = make_client_endpoint(&server_address, None).await;
1583        let conn2 = make_client_endpoint(&server_address, None).await;
1584        let mut s1 = conn1.open_uni().await.unwrap();
1585        let s2 = conn2.open_uni().await;
1586        if let Ok(mut s2) = s2 {
1587            s1.write_all(&[0u8]).await.unwrap();
1588            s1.finish().unwrap();
1589            // Send enough data to create more than 1 chunks.
1590            // The first will try to open the connection (which should fail).
1591            // The following chunks will enable the detection of connection failure.
1592            let data = vec![1u8; PACKET_DATA_SIZE * 2];
1593            s2.write_all(&data)
1594                .await
1595                .expect_err("shouldn't be able to open 2 connections");
1596        } else {
1597            // It has been noticed if there is already connection open against the server, this open_uni can fail
1598            // with ApplicationClosed(ApplicationClose) error due to CONNECTION_CLOSE_CODE_TOO_MANY before writing to
1599            // the stream -- expect it.
1600            assert_matches!(s2, Err(quinn::ConnectionError::ApplicationClosed(_)));
1601        }
1602    }
1603
1604    pub async fn check_multiple_streams(
1605        receiver: Receiver<PacketBatch>,
1606        server_address: SocketAddr,
1607    ) {
1608        let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);
1609        let conn2 = Arc::new(make_client_endpoint(&server_address, None).await);
1610        let mut num_expected_packets = 0;
1611        for i in 0..10 {
1612            info!("sending: {}", i);
1613            let c1 = conn1.clone();
1614            let c2 = conn2.clone();
1615            let mut s1 = c1.open_uni().await.unwrap();
1616            let mut s2 = c2.open_uni().await.unwrap();
1617            s1.write_all(&[0u8]).await.unwrap();
1618            s1.finish().unwrap();
1619            s2.write_all(&[0u8]).await.unwrap();
1620            s2.finish().unwrap();
1621            num_expected_packets += 2;
1622            sleep(Duration::from_millis(200)).await;
1623        }
1624        let mut all_packets = vec![];
1625        let now = Instant::now();
1626        let mut total_packets = 0;
1627        while now.elapsed().as_secs() < 10 {
1628            if let Ok(packets) = receiver.try_recv() {
1629                total_packets += packets.len();
1630                all_packets.push(packets)
1631            } else {
1632                sleep(Duration::from_secs(1)).await;
1633            }
1634            if total_packets == num_expected_packets {
1635                break;
1636            }
1637        }
1638        for batch in all_packets {
1639            for p in batch.iter() {
1640                assert_eq!(p.meta().size, 1);
1641            }
1642        }
1643        assert_eq!(total_packets, num_expected_packets);
1644    }
1645
1646    pub async fn check_multiple_writes(
1647        receiver: Receiver<PacketBatch>,
1648        server_address: SocketAddr,
1649        client_keypair: Option<&Keypair>,
1650    ) {
1651        let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);
1652
1653        // Send a full size packet with single byte writes.
1654        let num_bytes = PACKET_DATA_SIZE;
1655        let num_expected_packets = 1;
1656        let mut s1 = conn1.open_uni().await.unwrap();
1657        for _ in 0..num_bytes {
1658            s1.write_all(&[0u8]).await.unwrap();
1659        }
1660        s1.finish().unwrap();
1661
1662        let mut all_packets = vec![];
1663        let now = Instant::now();
1664        let mut total_packets = 0;
1665        while now.elapsed().as_secs() < 5 {
1666            // We're running in an async environment, we (almost) never
1667            // want to block
1668            if let Ok(packets) = receiver.try_recv() {
1669                total_packets += packets.len();
1670                all_packets.push(packets)
1671            } else {
1672                sleep(Duration::from_secs(1)).await;
1673            }
1674            if total_packets >= num_expected_packets {
1675                break;
1676            }
1677        }
1678        for batch in all_packets {
1679            for p in batch.iter() {
1680                assert_eq!(p.meta().size, num_bytes);
1681            }
1682        }
1683        assert_eq!(total_packets, num_expected_packets);
1684    }
1685
1686    pub async fn check_unstaked_node_connect_failure(server_address: SocketAddr) {
1687        let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);
1688
1689        // Send a full size packet with single byte writes.
1690        if let Ok(mut s1) = conn1.open_uni().await {
1691            for _ in 0..PACKET_DATA_SIZE {
1692                // Ignoring any errors here. s1.finish() will test the error condition
1693                s1.write_all(&[0u8]).await.unwrap_or_default();
1694            }
1695            s1.finish().unwrap_or_default();
1696            s1.stopped().await.unwrap_err();
1697        }
1698    }
1699
1700    #[tokio::test(flavor = "multi_thread")]
1701    async fn test_quic_server_exit() {
1702        let SpawnTestServerResult {
1703            join_handle,
1704            exit,
1705            receiver: _,
1706            server_address: _,
1707            stats: _,
1708        } = setup_quic_server(None, TestServerConfig::default());
1709        exit.store(true, Ordering::Relaxed);
1710        join_handle.await.unwrap();
1711    }
1712
1713    #[tokio::test(flavor = "multi_thread")]
1714    async fn test_quic_timeout() {
1715        solana_logger::setup();
1716        let SpawnTestServerResult {
1717            join_handle,
1718            exit,
1719            receiver,
1720            server_address,
1721            stats: _,
1722        } = setup_quic_server(None, TestServerConfig::default());
1723
1724        check_timeout(receiver, server_address).await;
1725        exit.store(true, Ordering::Relaxed);
1726        join_handle.await.unwrap();
1727    }
1728
1729    #[tokio::test(flavor = "multi_thread")]
1730    async fn test_packet_batcher() {
1731        solana_logger::setup();
1732        let (pkt_batch_sender, pkt_batch_receiver) = unbounded();
1733        let (ptk_sender, pkt_receiver) = async_unbounded();
1734        let exit = Arc::new(AtomicBool::new(false));
1735        let stats = Arc::new(StreamerStats::default());
1736
1737        let handle = tokio::spawn(packet_batch_sender(
1738            pkt_batch_sender,
1739            pkt_receiver,
1740            exit.clone(),
1741            stats,
1742            DEFAULT_TPU_COALESCE,
1743        ));
1744
1745        let num_packets = 1000;
1746
1747        for _i in 0..num_packets {
1748            let mut meta = Meta::default();
1749            let bytes = Bytes::from("Hello world");
1750            let size = bytes.len();
1751            meta.size = size;
1752            let packet_accum = PacketAccumulator {
1753                meta,
1754                chunks: smallvec::smallvec![bytes],
1755                start_time: Instant::now(),
1756            };
1757            ptk_sender.send(packet_accum).await.unwrap();
1758        }
1759        let mut i = 0;
1760        let start = Instant::now();
1761        while i < num_packets && start.elapsed().as_secs() < 2 {
1762            if let Ok(batch) = pkt_batch_receiver.try_recv() {
1763                i += batch.len();
1764            } else {
1765                sleep(Duration::from_millis(1)).await;
1766            }
1767        }
1768        assert_eq!(i, num_packets);
1769        exit.store(true, Ordering::Relaxed);
1770        // Explicit drop to wake up packet_batch_sender
1771        drop(ptk_sender);
1772        handle.await.unwrap();
1773    }
1774
1775    #[tokio::test(flavor = "multi_thread")]
1776    async fn test_quic_stream_timeout() {
1777        solana_logger::setup();
1778        let SpawnTestServerResult {
1779            join_handle,
1780            exit,
1781            receiver: _,
1782            server_address,
1783            stats,
1784        } = setup_quic_server(None, TestServerConfig::default());
1785
1786        let conn1 = make_client_endpoint(&server_address, None).await;
1787        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
1788        assert_eq!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
1789
1790        // Send one byte to start the stream
1791        let mut s1 = conn1.open_uni().await.unwrap();
1792        s1.write_all(&[0u8]).await.unwrap_or_default();
1793
1794        // Wait long enough for the stream to timeout in receiving chunks
1795        let sleep_time = DEFAULT_WAIT_FOR_CHUNK_TIMEOUT * 2;
1796        sleep(sleep_time).await;
1797
1798        // Test that the stream was created, but timed out in read
1799        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
1800        assert_ne!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
1801
1802        // Test that more writes to the stream will fail (i.e. the stream is no longer writable
1803        // after the timeouts)
1804        assert!(s1.write_all(&[0u8]).await.is_err());
1805
1806        exit.store(true, Ordering::Relaxed);
1807        join_handle.await.unwrap();
1808    }
1809
1810    #[tokio::test(flavor = "multi_thread")]
1811    async fn test_quic_server_block_multiple_connections() {
1812        solana_logger::setup();
1813        let SpawnTestServerResult {
1814            join_handle,
1815            exit,
1816            receiver: _,
1817            server_address,
1818            stats: _,
1819        } = setup_quic_server(None, TestServerConfig::default());
1820        check_block_multiple_connections(server_address).await;
1821        exit.store(true, Ordering::Relaxed);
1822        join_handle.await.unwrap();
1823    }
1824
1825    #[tokio::test(flavor = "multi_thread")]
1826    async fn test_quic_server_multiple_connections_on_single_client_endpoint() {
1827        solana_logger::setup();
1828
1829        let SpawnTestServerResult {
1830            join_handle,
1831            exit,
1832            receiver: _,
1833            server_address,
1834            stats,
1835        } = setup_quic_server(
1836            None,
1837            TestServerConfig {
1838                max_connections_per_peer: 2,
1839                ..Default::default()
1840            },
1841        );
1842
1843        let client_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
1844        let mut endpoint = quinn::Endpoint::new(
1845            EndpointConfig::default(),
1846            None,
1847            client_socket,
1848            Arc::new(TokioRuntime),
1849        )
1850        .unwrap();
1851        let default_keypair = Keypair::new();
1852        endpoint.set_default_client_config(get_client_config(&default_keypair));
1853        let conn1 = endpoint
1854            .connect(server_address, "localhost")
1855            .expect("Failed in connecting")
1856            .await
1857            .expect("Failed in waiting");
1858
1859        let conn2 = endpoint
1860            .connect(server_address, "localhost")
1861            .expect("Failed in connecting")
1862            .await
1863            .expect("Failed in waiting");
1864
1865        let mut s1 = conn1.open_uni().await.unwrap();
1866        s1.write_all(&[0u8]).await.unwrap();
1867        s1.finish().unwrap();
1868
1869        let mut s2 = conn2.open_uni().await.unwrap();
1870        conn1.close(
1871            CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1872            CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1873        );
1874
1875        let start = Instant::now();
1876        while stats.connection_removed.load(Ordering::Relaxed) != 1 {
1877            debug!("First connection not removed yet");
1878            sleep(Duration::from_millis(10)).await;
1879        }
1880        assert!(start.elapsed().as_secs() < 1);
1881
1882        s2.write_all(&[0u8]).await.unwrap();
1883        s2.finish().unwrap();
1884
1885        conn2.close(
1886            CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1887            CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1888        );
1889
1890        let start = Instant::now();
1891        while stats.connection_removed.load(Ordering::Relaxed) != 2 {
1892            debug!("Second connection not removed yet");
1893            sleep(Duration::from_millis(10)).await;
1894        }
1895        assert!(start.elapsed().as_secs() < 1);
1896
1897        exit.store(true, Ordering::Relaxed);
1898        join_handle.await.unwrap();
1899    }
1900
1901    #[tokio::test(flavor = "multi_thread")]
1902    async fn test_quic_server_multiple_writes() {
1903        solana_logger::setup();
1904        let SpawnTestServerResult {
1905            join_handle,
1906            exit,
1907            receiver,
1908            server_address,
1909            stats: _,
1910        } = setup_quic_server(None, TestServerConfig::default());
1911        check_multiple_writes(receiver, server_address, None).await;
1912        exit.store(true, Ordering::Relaxed);
1913        join_handle.await.unwrap();
1914    }
1915
1916    #[tokio::test(flavor = "multi_thread")]
1917    async fn test_quic_server_staked_connection_removal() {
1918        solana_logger::setup();
1919
1920        let client_keypair = Keypair::new();
1921        let stakes = HashMap::from([(client_keypair.pubkey(), 100_000)]);
1922        let staked_nodes = StakedNodes::new(
1923            Arc::new(stakes),
1924            HashMap::<Pubkey, u64>::default(), // overrides
1925        );
1926        let SpawnTestServerResult {
1927            join_handle,
1928            exit,
1929            receiver,
1930            server_address,
1931            stats,
1932        } = setup_quic_server(Some(staked_nodes), TestServerConfig::default());
1933        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
1934        exit.store(true, Ordering::Relaxed);
1935        join_handle.await.unwrap();
1936        sleep(Duration::from_millis(100)).await;
1937        assert_eq!(
1938            stats
1939                .connection_added_from_unstaked_peer
1940                .load(Ordering::Relaxed),
1941            0
1942        );
1943        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1944        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1945    }
1946
1947    #[tokio::test(flavor = "multi_thread")]
1948    async fn test_quic_server_zero_staked_connection_removal() {
1949        // In this test, the client has a pubkey, but is not in stake table.
1950        solana_logger::setup();
1951
1952        let client_keypair = Keypair::new();
1953        let stakes = HashMap::from([(client_keypair.pubkey(), 0)]);
1954        let staked_nodes = StakedNodes::new(
1955            Arc::new(stakes),
1956            HashMap::<Pubkey, u64>::default(), // overrides
1957        );
1958        let SpawnTestServerResult {
1959            join_handle,
1960            exit,
1961            receiver,
1962            server_address,
1963            stats,
1964        } = setup_quic_server(Some(staked_nodes), TestServerConfig::default());
1965        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
1966        exit.store(true, Ordering::Relaxed);
1967        join_handle.await.unwrap();
1968        sleep(Duration::from_millis(100)).await;
1969        assert_eq!(
1970            stats
1971                .connection_added_from_staked_peer
1972                .load(Ordering::Relaxed),
1973            0
1974        );
1975        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1976        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1977    }
1978
1979    #[tokio::test(flavor = "multi_thread")]
1980    async fn test_quic_server_unstaked_connection_removal() {
1981        solana_logger::setup();
1982        let SpawnTestServerResult {
1983            join_handle,
1984            exit,
1985            receiver,
1986            server_address,
1987            stats,
1988        } = setup_quic_server(None, TestServerConfig::default());
1989        check_multiple_writes(receiver, server_address, None).await;
1990        exit.store(true, Ordering::Relaxed);
1991        join_handle.await.unwrap();
1992        sleep(Duration::from_millis(100)).await;
1993        assert_eq!(
1994            stats
1995                .connection_added_from_staked_peer
1996                .load(Ordering::Relaxed),
1997            0
1998        );
1999        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
2000        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
2001    }
2002
2003    #[tokio::test(flavor = "multi_thread")]
2004    async fn test_quic_server_unstaked_node_connect_failure() {
2005        solana_logger::setup();
2006        let s = UdpSocket::bind("127.0.0.1:0").unwrap();
2007        let exit = Arc::new(AtomicBool::new(false));
2008        let (sender, _) = unbounded();
2009        let keypair = Keypair::new();
2010        let server_address = s.local_addr().unwrap();
2011        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
2012        let SpawnNonBlockingServerResult {
2013            endpoints: _,
2014            stats: _,
2015            thread: t,
2016            max_concurrent_connections: _,
2017        } = spawn_server(
2018            "quic_streamer_test",
2019            s,
2020            &keypair,
2021            sender,
2022            exit.clone(),
2023            1,
2024            staked_nodes,
2025            MAX_STAKED_CONNECTIONS,
2026            0, // Do not allow any connection from unstaked clients/nodes
2027            DEFAULT_MAX_STREAMS_PER_MS,
2028            DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE,
2029            DEFAULT_WAIT_FOR_CHUNK_TIMEOUT,
2030            DEFAULT_TPU_COALESCE,
2031        )
2032        .unwrap();
2033
2034        check_unstaked_node_connect_failure(server_address).await;
2035        exit.store(true, Ordering::Relaxed);
2036        t.await.unwrap();
2037    }
2038
2039    #[tokio::test(flavor = "multi_thread")]
2040    async fn test_quic_server_multiple_streams() {
2041        solana_logger::setup();
2042        let s = UdpSocket::bind("127.0.0.1:0").unwrap();
2043        let exit = Arc::new(AtomicBool::new(false));
2044        let (sender, receiver) = unbounded();
2045        let keypair = Keypair::new();
2046        let server_address = s.local_addr().unwrap();
2047        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
2048        let SpawnNonBlockingServerResult {
2049            endpoints: _,
2050            stats,
2051            thread: t,
2052            max_concurrent_connections: _,
2053        } = spawn_server(
2054            "quic_streamer_test",
2055            s,
2056            &keypair,
2057            sender,
2058            exit.clone(),
2059            2,
2060            staked_nodes,
2061            MAX_STAKED_CONNECTIONS,
2062            MAX_UNSTAKED_CONNECTIONS,
2063            DEFAULT_MAX_STREAMS_PER_MS,
2064            DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE,
2065            DEFAULT_WAIT_FOR_CHUNK_TIMEOUT,
2066            DEFAULT_TPU_COALESCE,
2067        )
2068        .unwrap();
2069
2070        check_multiple_streams(receiver, server_address).await;
2071        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
2072        assert_eq!(stats.total_new_streams.load(Ordering::Relaxed), 20);
2073        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 2);
2074        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
2075        exit.store(true, Ordering::Relaxed);
2076        t.await.unwrap();
2077        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 0);
2078        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
2079    }
2080
2081    #[test]
2082    fn test_prune_table_with_ip() {
2083        use std::net::Ipv4Addr;
2084        solana_logger::setup();
2085        let mut table = ConnectionTable::new();
2086        let mut num_entries = 5;
2087        let max_connections_per_peer = 10;
2088        let sockets: Vec<_> = (0..num_entries)
2089            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2090            .collect();
2091        let stats = Arc::new(StreamerStats::default());
2092        for (i, socket) in sockets.iter().enumerate() {
2093            table
2094                .try_add_connection(
2095                    ConnectionTableKey::IP(socket.ip()),
2096                    socket.port(),
2097                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2098                    None,
2099                    ConnectionPeerType::Unstaked,
2100                    i as u64,
2101                    max_connections_per_peer,
2102                )
2103                .unwrap();
2104        }
2105        num_entries += 1;
2106        table
2107            .try_add_connection(
2108                ConnectionTableKey::IP(sockets[0].ip()),
2109                sockets[0].port(),
2110                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2111                None,
2112                ConnectionPeerType::Unstaked,
2113                5,
2114                max_connections_per_peer,
2115            )
2116            .unwrap();
2117
2118        let new_size = 3;
2119        let pruned = table.prune_oldest(new_size);
2120        assert_eq!(pruned, num_entries as usize - new_size);
2121        for v in table.table.values() {
2122            for x in v {
2123                assert!((x.last_update() + 1) >= (num_entries as u64 - new_size as u64));
2124            }
2125        }
2126        assert_eq!(table.table.len(), new_size);
2127        assert_eq!(table.total_size, new_size);
2128        for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
2129            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
2130        }
2131        assert_eq!(table.total_size, 0);
2132        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2133    }
2134
2135    #[test]
2136    fn test_prune_table_with_unique_pubkeys() {
2137        solana_logger::setup();
2138        let mut table = ConnectionTable::new();
2139
2140        // We should be able to add more entries than max_connections_per_peer, since each entry is
2141        // from a different peer pubkey.
2142        let num_entries = 15;
2143        let max_connections_per_peer = 10;
2144        let stats = Arc::new(StreamerStats::default());
2145
2146        let pubkeys: Vec<_> = (0..num_entries).map(|_| Pubkey::new_unique()).collect();
2147        for (i, pubkey) in pubkeys.iter().enumerate() {
2148            table
2149                .try_add_connection(
2150                    ConnectionTableKey::Pubkey(*pubkey),
2151                    0,
2152                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2153                    None,
2154                    ConnectionPeerType::Unstaked,
2155                    i as u64,
2156                    max_connections_per_peer,
2157                )
2158                .unwrap();
2159        }
2160
2161        let new_size = 3;
2162        let pruned = table.prune_oldest(new_size);
2163        assert_eq!(pruned, num_entries as usize - new_size);
2164        assert_eq!(table.table.len(), new_size);
2165        assert_eq!(table.total_size, new_size);
2166        for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) {
2167            table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0);
2168        }
2169        assert_eq!(table.total_size, 0);
2170        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2171    }
2172
2173    #[test]
2174    fn test_prune_table_with_non_unique_pubkeys() {
2175        solana_logger::setup();
2176        let mut table = ConnectionTable::new();
2177
2178        let max_connections_per_peer = 10;
2179        let pubkey = Pubkey::new_unique();
2180        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2181
2182        (0..max_connections_per_peer).for_each(|i| {
2183            table
2184                .try_add_connection(
2185                    ConnectionTableKey::Pubkey(pubkey),
2186                    0,
2187                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2188                    None,
2189                    ConnectionPeerType::Unstaked,
2190                    i as u64,
2191                    max_connections_per_peer,
2192                )
2193                .unwrap();
2194        });
2195
2196        // We should NOT be able to add more entries than max_connections_per_peer, since we are
2197        // using the same peer pubkey.
2198        assert!(table
2199            .try_add_connection(
2200                ConnectionTableKey::Pubkey(pubkey),
2201                0,
2202                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2203                None,
2204                ConnectionPeerType::Unstaked,
2205                10,
2206                max_connections_per_peer,
2207            )
2208            .is_none());
2209
2210        // We should be able to add an entry from another peer pubkey
2211        let num_entries = max_connections_per_peer + 1;
2212        let pubkey2 = Pubkey::new_unique();
2213        assert!(table
2214            .try_add_connection(
2215                ConnectionTableKey::Pubkey(pubkey2),
2216                0,
2217                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2218                None,
2219                ConnectionPeerType::Unstaked,
2220                10,
2221                max_connections_per_peer,
2222            )
2223            .is_some());
2224
2225        assert_eq!(table.total_size, num_entries);
2226
2227        let new_max_size = 3;
2228        let pruned = table.prune_oldest(new_max_size);
2229        assert!(pruned >= num_entries - new_max_size);
2230        assert!(table.table.len() <= new_max_size);
2231        assert!(table.total_size <= new_max_size);
2232
2233        table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0);
2234        assert_eq!(table.total_size, 0);
2235        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2236    }
2237
2238    #[test]
2239    fn test_prune_table_random() {
2240        use std::net::Ipv4Addr;
2241        solana_logger::setup();
2242        let mut table = ConnectionTable::new();
2243        let num_entries = 5;
2244        let max_connections_per_peer = 10;
2245        let sockets: Vec<_> = (0..num_entries)
2246            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2247            .collect();
2248        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2249
2250        for (i, socket) in sockets.iter().enumerate() {
2251            table
2252                .try_add_connection(
2253                    ConnectionTableKey::IP(socket.ip()),
2254                    socket.port(),
2255                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2256                    None,
2257                    ConnectionPeerType::Staked((i + 1) as u64),
2258                    i as u64,
2259                    max_connections_per_peer,
2260                )
2261                .unwrap();
2262        }
2263
2264        // Try pruninng with threshold stake less than all the entries in the table
2265        // It should fail to prune (i.e. return 0 number of pruned entries)
2266        let pruned = table.prune_random(/*sample_size:*/ 2, /*threshold_stake:*/ 0);
2267        assert_eq!(pruned, 0);
2268
2269        // Try pruninng with threshold stake higher than all the entries in the table
2270        // It should succeed to prune (i.e. return 1 number of pruned entries)
2271        let pruned = table.prune_random(
2272            2,                      // sample_size
2273            num_entries as u64 + 1, // threshold_stake
2274        );
2275        assert_eq!(pruned, 1);
2276        // We had 5 connections and pruned 1, we should have 4 left
2277        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 4);
2278    }
2279
2280    #[test]
2281    fn test_remove_connections() {
2282        use std::net::Ipv4Addr;
2283        solana_logger::setup();
2284        let mut table = ConnectionTable::new();
2285        let num_ips = 5;
2286        let max_connections_per_peer = 10;
2287        let mut sockets: Vec<_> = (0..num_ips)
2288            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2289            .collect();
2290        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2291
2292        for (i, socket) in sockets.iter().enumerate() {
2293            table
2294                .try_add_connection(
2295                    ConnectionTableKey::IP(socket.ip()),
2296                    socket.port(),
2297                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2298                    None,
2299                    ConnectionPeerType::Unstaked,
2300                    (i * 2) as u64,
2301                    max_connections_per_peer,
2302                )
2303                .unwrap();
2304
2305            table
2306                .try_add_connection(
2307                    ConnectionTableKey::IP(socket.ip()),
2308                    socket.port(),
2309                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2310                    None,
2311                    ConnectionPeerType::Unstaked,
2312                    (i * 2 + 1) as u64,
2313                    max_connections_per_peer,
2314                )
2315                .unwrap();
2316        }
2317
2318        let single_connection_addr =
2319            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips, 0, 0, 0)), 0);
2320        table
2321            .try_add_connection(
2322                ConnectionTableKey::IP(single_connection_addr.ip()),
2323                single_connection_addr.port(),
2324                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2325                None,
2326                ConnectionPeerType::Unstaked,
2327                (num_ips * 2) as u64,
2328                max_connections_per_peer,
2329            )
2330            .unwrap();
2331
2332        let zero_connection_addr =
2333            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips + 1, 0, 0, 0)), 0);
2334
2335        sockets.push(single_connection_addr);
2336        sockets.push(zero_connection_addr);
2337
2338        for socket in sockets.iter() {
2339            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
2340        }
2341        assert_eq!(table.total_size, 0);
2342        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2343    }
2344
2345    #[test]
2346
2347    fn test_max_allowed_uni_streams() {
2348        assert_eq!(
2349            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0),
2350            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
2351        );
2352        assert_eq!(
2353            compute_max_allowed_uni_streams(ConnectionPeerType::Staked(10), 0),
2354            QUIC_MIN_STAKED_CONCURRENT_STREAMS
2355        );
2356        let delta =
2357            (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS - QUIC_MIN_STAKED_CONCURRENT_STREAMS) as f64;
2358        assert_eq!(
2359            compute_max_allowed_uni_streams(ConnectionPeerType::Staked(1000), 10000),
2360            QUIC_MAX_STAKED_CONCURRENT_STREAMS,
2361        );
2362        assert_eq!(
2363            compute_max_allowed_uni_streams(ConnectionPeerType::Staked(100), 10000),
2364            ((delta / (100_f64)) as usize + QUIC_MIN_STAKED_CONCURRENT_STREAMS)
2365                .min(QUIC_MAX_STAKED_CONCURRENT_STREAMS)
2366        );
2367        assert_eq!(
2368            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 10000),
2369            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
2370        );
2371    }
2372
2373    #[test]
2374    fn test_cacluate_receive_window_ratio_for_staked_node() {
2375        let mut max_stake = 10000;
2376        let mut min_stake = 0;
2377        let ratio = compute_receive_window_ratio_for_staked_node(max_stake, min_stake, min_stake);
2378        assert_eq!(ratio, QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO);
2379
2380        let ratio = compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake);
2381        let max_ratio = QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
2382        assert_eq!(ratio, max_ratio);
2383
2384        let ratio =
2385            compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake / 2);
2386        let average_ratio =
2387            (QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO + QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO) / 2;
2388        assert_eq!(ratio, average_ratio);
2389
2390        max_stake = 10000;
2391        min_stake = 10000;
2392        let ratio = compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake);
2393        assert_eq!(ratio, max_ratio);
2394
2395        max_stake = 0;
2396        min_stake = 0;
2397        let ratio = compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake);
2398        assert_eq!(ratio, max_ratio);
2399
2400        max_stake = 1000;
2401        min_stake = 10;
2402        let ratio =
2403            compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake + 10);
2404        assert_eq!(ratio, max_ratio);
2405    }
2406
2407    #[tokio::test(flavor = "multi_thread")]
2408    async fn test_throttling_check_no_packet_drop() {
2409        solana_logger::setup_with_default_filter();
2410
2411        let SpawnTestServerResult {
2412            join_handle,
2413            exit,
2414            receiver,
2415            server_address,
2416            stats,
2417        } = setup_quic_server(None, TestServerConfig::default());
2418
2419        let client_connection = make_client_endpoint(&server_address, None).await;
2420
2421        // unstaked connection can handle up to 100tps, so we should send in ~1s.
2422        let expected_num_txs = 100;
2423        let start_time = tokio::time::Instant::now();
2424        for i in 0..expected_num_txs {
2425            let mut send_stream = client_connection.open_uni().await.unwrap();
2426            let data = format!("{i}").into_bytes();
2427            send_stream.write_all(&data).await.unwrap();
2428            send_stream.finish().unwrap();
2429        }
2430        let elapsed_sending: f64 = start_time.elapsed().as_secs_f64();
2431        info!("Elapsed sending: {elapsed_sending}");
2432
2433        // check that delivered all of them
2434        let start_time = tokio::time::Instant::now();
2435        let mut num_txs_received = 0;
2436        while num_txs_received < expected_num_txs && start_time.elapsed() < Duration::from_secs(2) {
2437            if let Ok(packets) = receiver.try_recv() {
2438                num_txs_received += packets.len();
2439            } else {
2440                sleep(Duration::from_millis(100)).await;
2441            }
2442        }
2443        assert_eq!(expected_num_txs, num_txs_received);
2444
2445        // stop it
2446        exit.store(true, Ordering::Relaxed);
2447        join_handle.await.unwrap();
2448
2449        assert_eq!(
2450            stats.total_new_streams.load(Ordering::Relaxed),
2451            expected_num_txs
2452        );
2453        assert!(stats.throttled_unstaked_streams.load(Ordering::Relaxed) > 0);
2454    }
2455
2456    #[test]
2457    fn test_client_connection_tracker() {
2458        let stats = Arc::new(StreamerStats::default());
2459        let tracker_1 = ClientConnectionTracker::new(stats.clone(), 1);
2460        assert!(tracker_1.is_ok());
2461        assert!(ClientConnectionTracker::new(stats.clone(), 1).is_err());
2462        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 1);
2463        // dropping the connection, concurrent connections should become 0
2464        drop(tracker_1);
2465        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2466    }
2467
2468    #[tokio::test(flavor = "multi_thread")]
2469    async fn test_client_connection_close_invalid_stream() {
2470        let SpawnTestServerResult {
2471            join_handle,
2472            server_address,
2473            stats,
2474            exit,
2475            ..
2476        } = setup_quic_server(None, TestServerConfig::default());
2477
2478        let client_connection = make_client_endpoint(&server_address, None).await;
2479
2480        let mut send_stream = client_connection.open_uni().await.unwrap();
2481        send_stream
2482            .write_all(&[42; PACKET_DATA_SIZE + 1])
2483            .await
2484            .unwrap();
2485        match client_connection.closed().await {
2486            ConnectionError::ApplicationClosed(ApplicationClose { error_code, reason }) => {
2487                assert_eq!(error_code, CONNECTION_CLOSE_CODE_INVALID_STREAM.into());
2488                assert_eq!(reason, CONNECTION_CLOSE_REASON_INVALID_STREAM);
2489            }
2490            _ => panic!("unexpected close"),
2491        }
2492        assert_eq!(stats.invalid_stream_size.load(Ordering::Relaxed), 1);
2493        exit.store(true, Ordering::Relaxed);
2494        join_handle.await.unwrap();
2495    }
2496}