solana_streamer/nonblocking/
quic.rs

1use {
2    crate::{
3        nonblocking::{
4            connection_rate_limiter::{ConnectionRateLimiter, TotalConnectionRateLimiter},
5            stream_throttle::{
6                ConnectionStreamCounter, StakedStreamLoadEMA, STREAM_THROTTLING_INTERVAL,
7                STREAM_THROTTLING_INTERVAL_MS,
8            },
9        },
10        quic::{configure_server, QuicServerError, QuicServerParams, StreamerStats},
11        streamer::StakedNodes,
12    },
13    async_channel::{bounded as async_bounded, Receiver as AsyncReceiver, Sender as AsyncSender},
14    bytes::Bytes,
15    crossbeam_channel::Sender,
16    futures::{stream::FuturesUnordered, Future, StreamExt as _},
17    indexmap::map::{Entry, IndexMap},
18    percentage::Percentage,
19    quinn::{Accept, Connecting, Connection, Endpoint, EndpointConfig, TokioRuntime, VarInt},
20    quinn_proto::VarIntBoundsExceeded,
21    rand::{thread_rng, Rng},
22    smallvec::SmallVec,
23    solana_keypair::Keypair,
24    solana_measure::measure::Measure,
25    solana_packet::{Meta, PACKET_DATA_SIZE},
26    solana_perf::packet::{PacketBatch, PacketBatchRecycler, PACKETS_PER_BATCH},
27    solana_pubkey::Pubkey,
28    solana_quic_definitions::{
29        QUIC_CONNECTION_HANDSHAKE_TIMEOUT, QUIC_MAX_STAKED_CONCURRENT_STREAMS,
30        QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO, QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
31        QUIC_MIN_STAKED_CONCURRENT_STREAMS, QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO,
32        QUIC_TOTAL_STAKED_CONCURRENT_STREAMS, QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO,
33    },
34    solana_signature::Signature,
35    solana_time_utils as timing,
36    solana_tls_utils::get_pubkey_from_tls_certificate,
37    solana_transaction_metrics_tracker::signature_if_should_track_packet,
38    std::{
39        array,
40        fmt,
41        iter::repeat_with,
42        net::{IpAddr, SocketAddr, UdpSocket},
43        pin::Pin,
44        // CAUTION: be careful not to introduce any awaits while holding an RwLock.
45        sync::{
46            atomic::{AtomicBool, AtomicU64, Ordering},
47            Arc, RwLock,
48        },
49        task::Poll,
50        time::{Duration, Instant},
51    },
52    tokio::{
53        // CAUTION: It's kind of sketch that we're mixing async and sync locks (see the RwLock above).
54        // This is done so that sync code can also access the stake table.
55        // Make sure we don't hold a sync lock across an await - including the await to
56        // lock an async Mutex. This does not happen now and should not happen as long as we
57        // don't hold an async Mutex and sync RwLock at the same time (currently true)
58        // but if we do, the scope of the RwLock must always be a subset of the async Mutex
59        // (i.e. lock order is always async Mutex -> RwLock). Also, be careful not to
60        // introduce any other awaits while holding the RwLock.
61        select,
62        sync::{Mutex, MutexGuard},
63        task::JoinHandle,
64        time::{sleep, timeout},
65    },
66    tokio_util::sync::CancellationToken,
67};
68
69pub const DEFAULT_WAIT_FOR_CHUNK_TIMEOUT: Duration = Duration::from_secs(2);
70
71pub const ALPN_TPU_PROTOCOL_ID: &[u8] = b"solana-tpu";
72
73const CONNECTION_CLOSE_CODE_DROPPED_ENTRY: u32 = 1;
74const CONNECTION_CLOSE_REASON_DROPPED_ENTRY: &[u8] = b"dropped";
75
76const CONNECTION_CLOSE_CODE_DISALLOWED: u32 = 2;
77const CONNECTION_CLOSE_REASON_DISALLOWED: &[u8] = b"disallowed";
78
79const CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT: u32 = 3;
80const CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT: &[u8] = b"exceed_max_stream_count";
81
82const CONNECTION_CLOSE_CODE_TOO_MANY: u32 = 4;
83const CONNECTION_CLOSE_REASON_TOO_MANY: &[u8] = b"too_many";
84
85const CONNECTION_CLOSE_CODE_INVALID_STREAM: u32 = 5;
86const CONNECTION_CLOSE_REASON_INVALID_STREAM: &[u8] = b"invalid_stream";
87
88/// The new connections per minute from a particular IP address.
89/// Heuristically set to the default maximum concurrent connections
90/// per IP address. Might be adjusted later.
91#[deprecated(
92    since = "2.2.0",
93    note = "Use solana_streamer::quic::DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE"
94)]
95pub use crate::quic::DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE;
96/// Limit to 250K PPS
97#[deprecated(
98    since = "2.2.0",
99    note = "Use solana_streamer::quic::DEFAULT_MAX_STREAMS_PER_MS"
100)]
101pub use crate::quic::DEFAULT_MAX_STREAMS_PER_MS;
102
103/// Total new connection counts per second. Heuristically taken from
104/// the default staked and unstaked connection limits. Might be adjusted
105/// later.
106const TOTAL_CONNECTIONS_PER_SECOND: u64 = 2500;
107
108/// The threshold of the size of the connection rate limiter map. When
109/// the map size is above this, we will trigger a cleanup of older
110/// entries used by past requests.
111const CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD: usize = 100_000;
112
113// A struct to accumulate the bytes making up
114// a packet, along with their offsets, and the
115// packet metadata. We use this accumulator to avoid
116// multiple copies of the Bytes (when building up
117// the Packet and then when copying the Packet into a PacketBatch)
118#[derive(Clone)]
119struct PacketAccumulator {
120    pub meta: Meta,
121    pub chunks: SmallVec<[Bytes; 2]>,
122    pub start_time: Instant,
123}
124
125impl PacketAccumulator {
126    fn new(meta: Meta) -> Self {
127        Self {
128            meta,
129            chunks: SmallVec::default(),
130            start_time: Instant::now(),
131        }
132    }
133}
134
135#[derive(Copy, Clone, Debug)]
136pub enum ConnectionPeerType {
137    Unstaked,
138    Staked(u64),
139}
140
141impl ConnectionPeerType {
142    pub(crate) fn is_staked(&self) -> bool {
143        matches!(self, ConnectionPeerType::Staked(_))
144    }
145}
146
147pub struct SpawnNonBlockingServerResult {
148    pub endpoints: Vec<Endpoint>,
149    pub stats: Arc<StreamerStats>,
150    pub thread: JoinHandle<()>,
151    pub max_concurrent_connections: usize,
152}
153
154pub fn spawn_server(
155    name: &'static str,
156    sock: UdpSocket,
157    keypair: &Keypair,
158    packet_sender: Sender<PacketBatch>,
159    exit: Arc<AtomicBool>,
160    staked_nodes: Arc<RwLock<StakedNodes>>,
161    quic_server_params: QuicServerParams,
162) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
163    spawn_server_multi(
164        name,
165        vec![sock],
166        keypair,
167        packet_sender,
168        exit,
169        staked_nodes,
170        quic_server_params,
171    )
172}
173
174pub fn spawn_server_multi(
175    name: &'static str,
176    sockets: Vec<UdpSocket>,
177    keypair: &Keypair,
178    packet_sender: Sender<PacketBatch>,
179    exit: Arc<AtomicBool>,
180    staked_nodes: Arc<RwLock<StakedNodes>>,
181    quic_server_params: QuicServerParams,
182) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
183    info!("Start {name} quic server on {sockets:?}");
184    let QuicServerParams {
185        max_unstaked_connections,
186        max_staked_connections,
187        max_connections_per_peer,
188        max_streams_per_ms,
189        max_connections_per_ipaddr_per_min,
190        wait_for_chunk_timeout,
191        coalesce,
192        coalesce_channel_size,
193    } = quic_server_params;
194    let concurrent_connections = max_staked_connections + max_unstaked_connections;
195    let max_concurrent_connections = concurrent_connections + concurrent_connections / 4;
196    let (config, _) = configure_server(keypair)?;
197
198    let endpoints = sockets
199        .into_iter()
200        .map(|sock| {
201            Endpoint::new(
202                EndpointConfig::default(),
203                Some(config.clone()),
204                sock,
205                Arc::new(TokioRuntime),
206            )
207            .map_err(QuicServerError::EndpointFailed)
208        })
209        .collect::<Result<Vec<_>, _>>()?;
210    let stats = Arc::<StreamerStats>::default();
211    let handle = tokio::spawn(run_server(
212        name,
213        endpoints.clone(),
214        packet_sender,
215        exit,
216        max_connections_per_peer,
217        staked_nodes,
218        max_staked_connections,
219        max_unstaked_connections,
220        max_streams_per_ms,
221        max_connections_per_ipaddr_per_min,
222        stats.clone(),
223        wait_for_chunk_timeout,
224        coalesce,
225        coalesce_channel_size,
226        max_concurrent_connections,
227    ));
228    Ok(SpawnNonBlockingServerResult {
229        endpoints,
230        stats,
231        thread: handle,
232        max_concurrent_connections,
233    })
234}
235
236/// struct ease tracking connections of all stages, so that we do not have to
237/// litter the code with open connection tracking. This is added into the
238/// connection table as part of the ConnectionEntry. The reference is auto
239/// reduced when it is dropped.
240struct ClientConnectionTracker {
241    stats: Arc<StreamerStats>,
242}
243
244/// This is required by ConnectionEntry for supporting debug format.
245impl fmt::Debug for ClientConnectionTracker {
246    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247        f.debug_struct("StreamerClientConnection")
248            .field(
249                "open_connections:",
250                &self.stats.open_connections.load(Ordering::Relaxed),
251            )
252            .finish()
253    }
254}
255
256impl Drop for ClientConnectionTracker {
257    /// When this is dropped, reduce the open connection count.
258    fn drop(&mut self) {
259        self.stats.open_connections.fetch_sub(1, Ordering::Relaxed);
260    }
261}
262
263impl ClientConnectionTracker {
264    /// Check the max_concurrent_connections limit and if it is within the limit
265    /// create ClientConnectionTracker and increment open connection count. Otherwise returns Err
266    fn new(stats: Arc<StreamerStats>, max_concurrent_connections: usize) -> Result<Self, ()> {
267        let open_connections = stats.open_connections.fetch_add(1, Ordering::Relaxed);
268        if open_connections >= max_concurrent_connections {
269            stats.open_connections.fetch_sub(1, Ordering::Relaxed);
270            debug!(
271                "There are too many concurrent connections opened already: open: {}, max: {}",
272                open_connections, max_concurrent_connections
273            );
274            return Err(());
275        }
276
277        Ok(Self { stats })
278    }
279}
280
281#[allow(clippy::too_many_arguments)]
282async fn run_server(
283    name: &'static str,
284    endpoints: Vec<Endpoint>,
285    packet_sender: Sender<PacketBatch>,
286    exit: Arc<AtomicBool>,
287    max_connections_per_peer: usize,
288    staked_nodes: Arc<RwLock<StakedNodes>>,
289    max_staked_connections: usize,
290    max_unstaked_connections: usize,
291    max_streams_per_ms: u64,
292    max_connections_per_ipaddr_per_min: u64,
293    stats: Arc<StreamerStats>,
294    wait_for_chunk_timeout: Duration,
295    coalesce: Duration,
296    coalesce_channel_size: usize,
297    max_concurrent_connections: usize,
298) {
299    let rate_limiter = ConnectionRateLimiter::new(max_connections_per_ipaddr_per_min);
300    let overall_connection_rate_limiter =
301        TotalConnectionRateLimiter::new(TOTAL_CONNECTIONS_PER_SECOND);
302
303    const WAIT_FOR_CONNECTION_TIMEOUT: Duration = Duration::from_secs(1);
304    debug!("spawn quic server");
305    let mut last_datapoint = Instant::now();
306    let unstaked_connection_table: Arc<Mutex<ConnectionTable>> =
307        Arc::new(Mutex::new(ConnectionTable::new()));
308    let stream_load_ema = Arc::new(StakedStreamLoadEMA::new(
309        stats.clone(),
310        max_unstaked_connections,
311        max_streams_per_ms,
312    ));
313    stats
314        .quic_endpoints_count
315        .store(endpoints.len(), Ordering::Relaxed);
316    let staked_connection_table: Arc<Mutex<ConnectionTable>> =
317        Arc::new(Mutex::new(ConnectionTable::new()));
318    let (sender, receiver) = async_bounded(coalesce_channel_size);
319    tokio::spawn(packet_batch_sender(
320        packet_sender,
321        receiver,
322        exit.clone(),
323        stats.clone(),
324        coalesce,
325    ));
326
327    let mut accepts = endpoints
328        .iter()
329        .enumerate()
330        .map(|(i, incoming)| {
331            Box::pin(EndpointAccept {
332                accept: incoming.accept(),
333                endpoint: i,
334            })
335        })
336        .collect::<FuturesUnordered<_>>();
337
338    while !exit.load(Ordering::Relaxed) {
339        let timeout_connection = select! {
340            ready = accepts.next() => {
341                if let Some((connecting, i)) = ready {
342                    accepts.push(
343                        Box::pin(EndpointAccept {
344                            accept: endpoints[i].accept(),
345                            endpoint: i,
346                        }
347                    ));
348                    Ok(connecting)
349                } else {
350                    // we can't really get here - we never poll an empty FuturesUnordered
351                    continue
352                }
353            }
354            _ = tokio::time::sleep(WAIT_FOR_CONNECTION_TIMEOUT) => {
355                Err(())
356            }
357        };
358
359        if last_datapoint.elapsed().as_secs() >= 5 {
360            stats.report(name);
361            last_datapoint = Instant::now();
362        }
363
364        if let Ok(Some(incoming)) = timeout_connection {
365            stats
366                .total_incoming_connection_attempts
367                .fetch_add(1, Ordering::Relaxed);
368
369            let remote_address = incoming.remote_address();
370
371            // first do per IpAddr rate limiting
372            if rate_limiter.len() > CONNECTION_RATE_LIMITER_CLEANUP_SIZE_THRESHOLD {
373                rate_limiter.retain_recent();
374            }
375            stats
376                .connection_rate_limiter_length
377                .store(rate_limiter.len(), Ordering::Relaxed);
378            debug!("Got a connection {remote_address:?}");
379            if !rate_limiter.is_allowed(&remote_address.ip()) {
380                debug!(
381                    "Reject connection from {:?} -- rate limiting exceeded",
382                    remote_address
383                );
384                stats
385                    .connection_rate_limited_per_ipaddr
386                    .fetch_add(1, Ordering::Relaxed);
387                incoming.ignore();
388                continue;
389            }
390
391            // then check overall connection rate limit:
392            if !overall_connection_rate_limiter.is_allowed() {
393                debug!(
394                    "Reject connection from {:?} -- total rate limiting exceeded",
395                    remote_address.ip()
396                );
397                stats
398                    .connection_rate_limited_across_all
399                    .fetch_add(1, Ordering::Relaxed);
400                incoming.ignore();
401                continue;
402            }
403
404            let Ok(client_connection_tracker) =
405                ClientConnectionTracker::new(stats.clone(), max_concurrent_connections)
406            else {
407                stats
408                    .refused_connections_too_many_open_connections
409                    .fetch_add(1, Ordering::Relaxed);
410                incoming.refuse();
411                continue;
412            };
413
414            stats
415                .outstanding_incoming_connection_attempts
416                .fetch_add(1, Ordering::Relaxed);
417            let connecting = incoming.accept();
418            match connecting {
419                Ok(connecting) => {
420                    tokio::spawn(setup_connection(
421                        connecting,
422                        client_connection_tracker,
423                        unstaked_connection_table.clone(),
424                        staked_connection_table.clone(),
425                        sender.clone(),
426                        max_connections_per_peer,
427                        staked_nodes.clone(),
428                        max_staked_connections,
429                        max_unstaked_connections,
430                        max_streams_per_ms,
431                        stats.clone(),
432                        wait_for_chunk_timeout,
433                        stream_load_ema.clone(),
434                    ));
435                }
436                Err(err) => {
437                    debug!("Incoming::accept(): error {:?}", err);
438                }
439            }
440        } else {
441            debug!("accept(): Timed out waiting for connection");
442        }
443    }
444}
445
446fn prune_unstaked_connection_table(
447    unstaked_connection_table: &mut ConnectionTable,
448    max_unstaked_connections: usize,
449    stats: Arc<StreamerStats>,
450) {
451    if unstaked_connection_table.total_size >= max_unstaked_connections {
452        const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90;
453        let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE);
454
455        let max_connections = max_percentage_full.apply_to(max_unstaked_connections);
456        let num_pruned = unstaked_connection_table.prune_oldest(max_connections);
457        stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
458    }
459}
460
461pub fn get_remote_pubkey(connection: &Connection) -> Option<Pubkey> {
462    // Use the client cert only if it is self signed and the chain length is 1.
463    connection
464        .peer_identity()?
465        .downcast::<Vec<rustls::pki_types::CertificateDer>>()
466        .ok()
467        .filter(|certs| certs.len() == 1)?
468        .first()
469        .and_then(get_pubkey_from_tls_certificate)
470}
471
472fn get_connection_stake(
473    connection: &Connection,
474    staked_nodes: &RwLock<StakedNodes>,
475) -> Option<(Pubkey, u64, u64, u64, u64)> {
476    let pubkey = get_remote_pubkey(connection)?;
477    debug!("Peer public key is {pubkey:?}");
478    let staked_nodes = staked_nodes.read().unwrap();
479    Some((
480        pubkey,
481        staked_nodes.get_node_stake(&pubkey)?,
482        staked_nodes.total_stake(),
483        staked_nodes.max_stake(),
484        staked_nodes.min_stake(),
485    ))
486}
487
488pub fn compute_max_allowed_uni_streams(peer_type: ConnectionPeerType, total_stake: u64) -> usize {
489    match peer_type {
490        ConnectionPeerType::Staked(peer_stake) => {
491            // No checked math for f64 type. So let's explicitly check for 0 here
492            if total_stake == 0 || peer_stake > total_stake {
493                warn!(
494                    "Invalid stake values: peer_stake: {:?}, total_stake: {:?}",
495                    peer_stake, total_stake,
496                );
497
498                QUIC_MIN_STAKED_CONCURRENT_STREAMS
499            } else {
500                let delta = (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS
501                    - QUIC_MIN_STAKED_CONCURRENT_STREAMS) as f64;
502
503                (((peer_stake as f64 / total_stake as f64) * delta) as usize
504                    + QUIC_MIN_STAKED_CONCURRENT_STREAMS)
505                    .clamp(
506                        QUIC_MIN_STAKED_CONCURRENT_STREAMS,
507                        QUIC_MAX_STAKED_CONCURRENT_STREAMS,
508                    )
509            }
510        }
511        ConnectionPeerType::Unstaked => QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
512    }
513}
514
515enum ConnectionHandlerError {
516    ConnectionAddError,
517    MaxStreamError,
518}
519
520#[derive(Clone)]
521struct NewConnectionHandlerParams {
522    // In principle, the code can be made to work with a crossbeam channel
523    // as long as we're careful never to use a blocking recv or send call
524    // but I've found that it's simply too easy to accidentally block
525    // in async code when using the crossbeam channel, so for the sake of maintainability,
526    // we're sticking with an async channel
527    packet_sender: AsyncSender<PacketAccumulator>,
528    remote_pubkey: Option<Pubkey>,
529    peer_type: ConnectionPeerType,
530    total_stake: u64,
531    max_connections_per_peer: usize,
532    stats: Arc<StreamerStats>,
533    max_stake: u64,
534    min_stake: u64,
535}
536
537impl NewConnectionHandlerParams {
538    fn new_unstaked(
539        packet_sender: AsyncSender<PacketAccumulator>,
540        max_connections_per_peer: usize,
541        stats: Arc<StreamerStats>,
542    ) -> NewConnectionHandlerParams {
543        NewConnectionHandlerParams {
544            packet_sender,
545            remote_pubkey: None,
546            peer_type: ConnectionPeerType::Unstaked,
547            total_stake: 0,
548            max_connections_per_peer,
549            stats,
550            max_stake: 0,
551            min_stake: 0,
552        }
553    }
554}
555
556fn handle_and_cache_new_connection(
557    client_connection_tracker: ClientConnectionTracker,
558    connection: Connection,
559    mut connection_table_l: MutexGuard<ConnectionTable>,
560    connection_table: Arc<Mutex<ConnectionTable>>,
561    params: &NewConnectionHandlerParams,
562    wait_for_chunk_timeout: Duration,
563    stream_load_ema: Arc<StakedStreamLoadEMA>,
564) -> Result<(), ConnectionHandlerError> {
565    if let Ok(max_uni_streams) = VarInt::from_u64(compute_max_allowed_uni_streams(
566        params.peer_type,
567        params.total_stake,
568    ) as u64)
569    {
570        let remote_addr = connection.remote_address();
571        let receive_window =
572            compute_recieve_window(params.max_stake, params.min_stake, params.peer_type);
573
574        debug!(
575            "Peer type {:?}, total stake {}, max streams {} receive_window {:?} from peer {}",
576            params.peer_type,
577            params.total_stake,
578            max_uni_streams.into_inner(),
579            receive_window,
580            remote_addr,
581        );
582
583        if let Some((last_update, cancel_connection, stream_counter)) = connection_table_l
584            .try_add_connection(
585                ConnectionTableKey::new(remote_addr.ip(), params.remote_pubkey),
586                remote_addr.port(),
587                client_connection_tracker,
588                Some(connection.clone()),
589                params.peer_type,
590                timing::timestamp(),
591                params.max_connections_per_peer,
592            )
593        {
594            drop(connection_table_l);
595
596            if let Ok(receive_window) = receive_window {
597                connection.set_receive_window(receive_window);
598            }
599            connection.set_max_concurrent_uni_streams(max_uni_streams);
600
601            tokio::spawn(handle_connection(
602                connection,
603                remote_addr,
604                last_update,
605                connection_table,
606                cancel_connection,
607                params.clone(),
608                wait_for_chunk_timeout,
609                stream_load_ema,
610                stream_counter,
611            ));
612            Ok(())
613        } else {
614            params
615                .stats
616                .connection_add_failed
617                .fetch_add(1, Ordering::Relaxed);
618            Err(ConnectionHandlerError::ConnectionAddError)
619        }
620    } else {
621        connection.close(
622            CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT.into(),
623            CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT,
624        );
625        params
626            .stats
627            .connection_add_failed_invalid_stream_count
628            .fetch_add(1, Ordering::Relaxed);
629        Err(ConnectionHandlerError::MaxStreamError)
630    }
631}
632
633async fn prune_unstaked_connections_and_add_new_connection(
634    client_connection_tracker: ClientConnectionTracker,
635    connection: Connection,
636    connection_table: Arc<Mutex<ConnectionTable>>,
637    max_connections: usize,
638    params: &NewConnectionHandlerParams,
639    wait_for_chunk_timeout: Duration,
640    stream_load_ema: Arc<StakedStreamLoadEMA>,
641) -> Result<(), ConnectionHandlerError> {
642    let stats = params.stats.clone();
643    if max_connections > 0 {
644        let connection_table_clone = connection_table.clone();
645        let mut connection_table = connection_table.lock().await;
646        prune_unstaked_connection_table(&mut connection_table, max_connections, stats);
647        handle_and_cache_new_connection(
648            client_connection_tracker,
649            connection,
650            connection_table,
651            connection_table_clone,
652            params,
653            wait_for_chunk_timeout,
654            stream_load_ema,
655        )
656    } else {
657        connection.close(
658            CONNECTION_CLOSE_CODE_DISALLOWED.into(),
659            CONNECTION_CLOSE_REASON_DISALLOWED,
660        );
661        Err(ConnectionHandlerError::ConnectionAddError)
662    }
663}
664
665/// Calculate the ratio for per connection receive window from a staked peer
666fn compute_receive_window_ratio_for_staked_node(max_stake: u64, min_stake: u64, stake: u64) -> u64 {
667    // Testing shows the maximum througput from a connection is achieved at receive_window =
668    // PACKET_DATA_SIZE * 10. Beyond that, there is not much gain. We linearly map the
669    // stake to the ratio range from QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO to
670    // QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO. Where the linear algebra of finding the ratio 'r'
671    // for stake 's' is,
672    // r(s) = a * s + b. Given the max_stake, min_stake, max_ratio, min_ratio, we can find
673    // a and b.
674
675    if stake > max_stake {
676        return QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
677    }
678
679    let max_ratio = QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
680    let min_ratio = QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO;
681    if max_stake > min_stake {
682        let a = (max_ratio - min_ratio) as f64 / (max_stake - min_stake) as f64;
683        let b = max_ratio as f64 - ((max_stake as f64) * a);
684        let ratio = (a * stake as f64) + b;
685        ratio.round() as u64
686    } else {
687        QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO
688    }
689}
690
691fn compute_recieve_window(
692    max_stake: u64,
693    min_stake: u64,
694    peer_type: ConnectionPeerType,
695) -> Result<VarInt, VarIntBoundsExceeded> {
696    match peer_type {
697        ConnectionPeerType::Unstaked => {
698            VarInt::from_u64(PACKET_DATA_SIZE as u64 * QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO)
699        }
700        ConnectionPeerType::Staked(peer_stake) => {
701            let ratio =
702                compute_receive_window_ratio_for_staked_node(max_stake, min_stake, peer_stake);
703            VarInt::from_u64(PACKET_DATA_SIZE as u64 * ratio)
704        }
705    }
706}
707
708#[allow(clippy::too_many_arguments)]
709async fn setup_connection(
710    connecting: Connecting,
711    client_connection_tracker: ClientConnectionTracker,
712    unstaked_connection_table: Arc<Mutex<ConnectionTable>>,
713    staked_connection_table: Arc<Mutex<ConnectionTable>>,
714    packet_sender: AsyncSender<PacketAccumulator>,
715    max_connections_per_peer: usize,
716    staked_nodes: Arc<RwLock<StakedNodes>>,
717    max_staked_connections: usize,
718    max_unstaked_connections: usize,
719    max_streams_per_ms: u64,
720    stats: Arc<StreamerStats>,
721    wait_for_chunk_timeout: Duration,
722    stream_load_ema: Arc<StakedStreamLoadEMA>,
723) {
724    const PRUNE_RANDOM_SAMPLE_SIZE: usize = 2;
725    let from = connecting.remote_address();
726    let res = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await;
727    stats
728        .outstanding_incoming_connection_attempts
729        .fetch_sub(1, Ordering::Relaxed);
730    if let Ok(connecting_result) = res {
731        match connecting_result {
732            Ok(new_connection) => {
733                stats.total_new_connections.fetch_add(1, Ordering::Relaxed);
734
735                let params = get_connection_stake(&new_connection, &staked_nodes).map_or(
736                    NewConnectionHandlerParams::new_unstaked(
737                        packet_sender.clone(),
738                        max_connections_per_peer,
739                        stats.clone(),
740                    ),
741                    |(pubkey, stake, total_stake, max_stake, min_stake)| {
742                        // The heuristic is that the stake should be large engouh to have 1 stream pass throuh within one throttle
743                        // interval during which we allow max (MAX_STREAMS_PER_MS * STREAM_THROTTLING_INTERVAL_MS) streams.
744                        let min_stake_ratio =
745                            1_f64 / (max_streams_per_ms * STREAM_THROTTLING_INTERVAL_MS) as f64;
746                        let stake_ratio = stake as f64 / total_stake as f64;
747                        let peer_type = if stake_ratio < min_stake_ratio {
748                            // If it is a staked connection with ultra low stake ratio, treat it as unstaked.
749                            ConnectionPeerType::Unstaked
750                        } else {
751                            ConnectionPeerType::Staked(stake)
752                        };
753                        NewConnectionHandlerParams {
754                            packet_sender,
755                            remote_pubkey: Some(pubkey),
756                            peer_type,
757                            total_stake,
758                            max_connections_per_peer,
759                            stats: stats.clone(),
760                            max_stake,
761                            min_stake,
762                        }
763                    },
764                );
765
766                match params.peer_type {
767                    ConnectionPeerType::Staked(stake) => {
768                        let mut connection_table_l = staked_connection_table.lock().await;
769
770                        if connection_table_l.total_size >= max_staked_connections {
771                            let num_pruned =
772                                connection_table_l.prune_random(PRUNE_RANDOM_SAMPLE_SIZE, stake);
773                            stats.num_evictions.fetch_add(num_pruned, Ordering::Relaxed);
774                        }
775
776                        if connection_table_l.total_size < max_staked_connections {
777                            if let Ok(()) = handle_and_cache_new_connection(
778                                client_connection_tracker,
779                                new_connection,
780                                connection_table_l,
781                                staked_connection_table.clone(),
782                                &params,
783                                wait_for_chunk_timeout,
784                                stream_load_ema.clone(),
785                            ) {
786                                stats
787                                    .connection_added_from_staked_peer
788                                    .fetch_add(1, Ordering::Relaxed);
789                            }
790                        } else {
791                            // If we couldn't prune a connection in the staked connection table, let's
792                            // put this connection in the unstaked connection table. If needed, prune a
793                            // connection from the unstaked connection table.
794                            if let Ok(()) = prune_unstaked_connections_and_add_new_connection(
795                                client_connection_tracker,
796                                new_connection,
797                                unstaked_connection_table.clone(),
798                                max_unstaked_connections,
799                                &params,
800                                wait_for_chunk_timeout,
801                                stream_load_ema.clone(),
802                            )
803                            .await
804                            {
805                                stats
806                                    .connection_added_from_staked_peer
807                                    .fetch_add(1, Ordering::Relaxed);
808                            } else {
809                                stats
810                                    .connection_add_failed_on_pruning
811                                    .fetch_add(1, Ordering::Relaxed);
812                                stats
813                                    .connection_add_failed_staked_node
814                                    .fetch_add(1, Ordering::Relaxed);
815                            }
816                        }
817                    }
818                    ConnectionPeerType::Unstaked => {
819                        if let Ok(()) = prune_unstaked_connections_and_add_new_connection(
820                            client_connection_tracker,
821                            new_connection,
822                            unstaked_connection_table.clone(),
823                            max_unstaked_connections,
824                            &params,
825                            wait_for_chunk_timeout,
826                            stream_load_ema.clone(),
827                        )
828                        .await
829                        {
830                            stats
831                                .connection_added_from_unstaked_peer
832                                .fetch_add(1, Ordering::Relaxed);
833                        } else {
834                            stats
835                                .connection_add_failed_unstaked_node
836                                .fetch_add(1, Ordering::Relaxed);
837                        }
838                    }
839                }
840            }
841            Err(e) => {
842                handle_connection_error(e, &stats, from);
843            }
844        }
845    } else {
846        stats
847            .connection_setup_timeout
848            .fetch_add(1, Ordering::Relaxed);
849    }
850}
851
852fn handle_connection_error(e: quinn::ConnectionError, stats: &StreamerStats, from: SocketAddr) {
853    debug!("error: {:?} from: {:?}", e, from);
854    stats.connection_setup_error.fetch_add(1, Ordering::Relaxed);
855    match e {
856        quinn::ConnectionError::TimedOut => {
857            stats
858                .connection_setup_error_timed_out
859                .fetch_add(1, Ordering::Relaxed);
860        }
861        quinn::ConnectionError::ConnectionClosed(_) => {
862            stats
863                .connection_setup_error_closed
864                .fetch_add(1, Ordering::Relaxed);
865        }
866        quinn::ConnectionError::TransportError(_) => {
867            stats
868                .connection_setup_error_transport
869                .fetch_add(1, Ordering::Relaxed);
870        }
871        quinn::ConnectionError::ApplicationClosed(_) => {
872            stats
873                .connection_setup_error_app_closed
874                .fetch_add(1, Ordering::Relaxed);
875        }
876        quinn::ConnectionError::Reset => {
877            stats
878                .connection_setup_error_reset
879                .fetch_add(1, Ordering::Relaxed);
880        }
881        quinn::ConnectionError::LocallyClosed => {
882            stats
883                .connection_setup_error_locally_closed
884                .fetch_add(1, Ordering::Relaxed);
885        }
886        _ => {}
887    }
888}
889
890// Holder(s) of the AsyncSender<PacketAccumulator> on the other end should not
891// wait for this function to exit
892async fn packet_batch_sender(
893    packet_sender: Sender<PacketBatch>,
894    packet_receiver: AsyncReceiver<PacketAccumulator>,
895    exit: Arc<AtomicBool>,
896    stats: Arc<StreamerStats>,
897    coalesce: Duration,
898) {
899    trace!("enter packet_batch_sender");
900    let recycler = PacketBatchRecycler::default();
901    let mut batch_start_time = Instant::now();
902    loop {
903        let mut packet_perf_measure: Vec<([u8; 64], Instant)> = Vec::default();
904        let mut packet_batch =
905            PacketBatch::new_with_recycler(&recycler, PACKETS_PER_BATCH, "quic_packet_coalescer");
906        let mut total_bytes: usize = 0;
907
908        stats
909            .total_packet_batches_allocated
910            .fetch_add(1, Ordering::Relaxed);
911        stats
912            .total_packets_allocated
913            .fetch_add(PACKETS_PER_BATCH, Ordering::Relaxed);
914
915        loop {
916            if exit.load(Ordering::Relaxed) {
917                return;
918            }
919            let elapsed = batch_start_time.elapsed();
920            if packet_batch.len() >= PACKETS_PER_BATCH
921                || (!packet_batch.is_empty() && elapsed >= coalesce)
922            {
923                let len = packet_batch.len();
924                track_streamer_fetch_packet_performance(&packet_perf_measure, &stats);
925
926                if let Err(e) = packet_sender.send(packet_batch) {
927                    stats
928                        .total_packet_batch_send_err
929                        .fetch_add(1, Ordering::Relaxed);
930                    trace!("Send error: {}", e);
931                } else {
932                    stats
933                        .total_packet_batches_sent
934                        .fetch_add(1, Ordering::Relaxed);
935
936                    stats
937                        .total_packets_sent_to_consumer
938                        .fetch_add(len, Ordering::Relaxed);
939
940                    stats
941                        .total_bytes_sent_to_consumer
942                        .fetch_add(total_bytes, Ordering::Relaxed);
943
944                    trace!("Sent {} packet batch", len);
945                }
946                break;
947            }
948
949            let timeout_res = if !packet_batch.is_empty() {
950                // If we get here, elapsed < coalesce (see above if condition)
951                timeout(coalesce - elapsed, packet_receiver.recv()).await
952            } else {
953                // Small bit of non-idealness here: the holder(s) of the other end
954                // of packet_receiver must drop it (without waiting for us to exit)
955                // or we have a chance of sleeping here forever
956                // and never polling exit. Not a huge deal in practice as the
957                // only time this happens is when we tear down the server
958                // and at that time the other end does indeed not wait for us
959                // to exit here
960                Ok(packet_receiver.recv().await)
961            };
962
963            if let Ok(Ok(packet_accumulator)) = timeout_res {
964                // Start the timeout from when the packet batch first becomes non-empty
965                if packet_batch.is_empty() {
966                    batch_start_time = Instant::now();
967                }
968
969                unsafe {
970                    packet_batch.set_len(packet_batch.len() + 1);
971                }
972
973                let i = packet_batch.len() - 1;
974                *packet_batch[i].meta_mut() = packet_accumulator.meta;
975                let num_chunks = packet_accumulator.chunks.len();
976                let mut offset = 0;
977                for chunk in packet_accumulator.chunks {
978                    packet_batch[i].buffer_mut()[offset..offset + chunk.len()]
979                        .copy_from_slice(&chunk);
980                    offset += chunk.len();
981                }
982
983                total_bytes += packet_batch[i].meta().size;
984
985                if let Some(signature) = signature_if_should_track_packet(&packet_batch[i])
986                    .ok()
987                    .flatten()
988                {
989                    packet_perf_measure.push((*signature, packet_accumulator.start_time));
990                    // we set the PERF_TRACK_PACKET on
991                    packet_batch[i].meta_mut().set_track_performance(true);
992                }
993                stats
994                    .total_chunks_processed_by_batcher
995                    .fetch_add(num_chunks, Ordering::Relaxed);
996            }
997        }
998    }
999}
1000
1001fn track_streamer_fetch_packet_performance(
1002    packet_perf_measure: &[([u8; 64], Instant)],
1003    stats: &StreamerStats,
1004) {
1005    if packet_perf_measure.is_empty() {
1006        return;
1007    }
1008    let mut measure = Measure::start("track_perf");
1009    let mut process_sampled_packets_us_hist = stats.process_sampled_packets_us_hist.lock().unwrap();
1010
1011    let now = Instant::now();
1012    for (signature, start_time) in packet_perf_measure {
1013        let duration = now.duration_since(*start_time);
1014        debug!(
1015            "QUIC streamer fetch stage took {duration:?} for transaction {:?}",
1016            Signature::from(*signature)
1017        );
1018        process_sampled_packets_us_hist
1019            .increment(duration.as_micros() as u64)
1020            .unwrap();
1021    }
1022
1023    drop(process_sampled_packets_us_hist);
1024    measure.stop();
1025    stats
1026        .perf_track_overhead_us
1027        .fetch_add(measure.as_us(), Ordering::Relaxed);
1028}
1029
1030async fn handle_connection(
1031    connection: Connection,
1032    remote_addr: SocketAddr,
1033    last_update: Arc<AtomicU64>,
1034    connection_table: Arc<Mutex<ConnectionTable>>,
1035    cancel: CancellationToken,
1036    params: NewConnectionHandlerParams,
1037    wait_for_chunk_timeout: Duration,
1038    stream_load_ema: Arc<StakedStreamLoadEMA>,
1039    stream_counter: Arc<ConnectionStreamCounter>,
1040) {
1041    let NewConnectionHandlerParams {
1042        packet_sender,
1043        peer_type,
1044        remote_pubkey,
1045        stats,
1046        total_stake,
1047        ..
1048    } = params;
1049
1050    debug!(
1051        "quic new connection {} streams: {} connections: {}",
1052        remote_addr,
1053        stats.total_streams.load(Ordering::Relaxed),
1054        stats.total_connections.load(Ordering::Relaxed),
1055    );
1056    stats.total_connections.fetch_add(1, Ordering::Relaxed);
1057
1058    'conn: loop {
1059        // Wait for new streams. If the peer is disconnected we get a cancellation signal and stop
1060        // the connection task.
1061        let mut stream = select! {
1062            stream = connection.accept_uni() => match stream {
1063                Ok(stream) => stream,
1064                Err(e) => {
1065                    debug!("stream error: {:?}", e);
1066                    break;
1067                }
1068            },
1069            _ = cancel.cancelled() => break,
1070        };
1071
1072        let max_streams_per_throttling_interval =
1073            stream_load_ema.available_load_capacity_in_throttling_duration(peer_type, total_stake);
1074
1075        let throttle_interval_start = stream_counter.reset_throttling_params_if_needed();
1076        let streams_read_in_throttle_interval = stream_counter.stream_count.load(Ordering::Relaxed);
1077        if streams_read_in_throttle_interval >= max_streams_per_throttling_interval {
1078            // The peer is sending faster than we're willing to read. Sleep for what's
1079            // left of this read interval so the peer backs off.
1080            let throttle_duration =
1081                STREAM_THROTTLING_INTERVAL.saturating_sub(throttle_interval_start.elapsed());
1082
1083            if !throttle_duration.is_zero() {
1084                debug!("Throttling stream from {remote_addr:?}, peer type: {:?}, total stake: {}, \
1085                                    max_streams_per_interval: {max_streams_per_throttling_interval}, read_interval_streams: {streams_read_in_throttle_interval} \
1086                                    throttle_duration: {throttle_duration:?}",
1087                                    peer_type, total_stake);
1088                stats.throttled_streams.fetch_add(1, Ordering::Relaxed);
1089                match peer_type {
1090                    ConnectionPeerType::Unstaked => {
1091                        stats
1092                            .throttled_unstaked_streams
1093                            .fetch_add(1, Ordering::Relaxed);
1094                    }
1095                    ConnectionPeerType::Staked(_) => {
1096                        stats
1097                            .throttled_staked_streams
1098                            .fetch_add(1, Ordering::Relaxed);
1099                    }
1100                }
1101                sleep(throttle_duration).await;
1102            }
1103        }
1104        stream_load_ema.increment_load(peer_type);
1105        stream_counter.stream_count.fetch_add(1, Ordering::Relaxed);
1106        stats.total_streams.fetch_add(1, Ordering::Relaxed);
1107        stats.total_new_streams.fetch_add(1, Ordering::Relaxed);
1108
1109        let mut meta = Meta::default();
1110        meta.set_socket_addr(&remote_addr);
1111        meta.set_from_staked_node(matches!(peer_type, ConnectionPeerType::Staked(_)));
1112        let mut accum = PacketAccumulator::new(meta);
1113
1114        // Virtually all small transactions will fit in 1 chunk. Larger transactions will fit in 1
1115        // or 2 chunks if the first chunk starts towards the end of a datagram. A small number of
1116        // transaction will have other protocol frames inserted in the middle. Empirically it's been
1117        // observed that 4 is the maximum number of chunks txs get split into.
1118        //
1119        // Bytes values are small, so overall the array takes only 128 bytes, and the "cost" of
1120        // overallocating a few bytes is negligible compared to the cost of having to do multiple
1121        // read_chunks() calls.
1122        let mut chunks: [Bytes; 4] = array::from_fn(|_| Bytes::new());
1123
1124        loop {
1125            // Read the next chunks, waiting up to `wait_for_chunk_timeout`. If we don't get chunks
1126            // before then, we assume the stream is dead. This can only happen if there's severe
1127            // packet loss or the peer stops sending for whatever reason.
1128            let n_chunks = match tokio::select! {
1129                chunk = tokio::time::timeout(
1130                    wait_for_chunk_timeout,
1131                    stream.read_chunks(&mut chunks)) => chunk,
1132
1133                // If the peer gets disconnected stop the task right away.
1134                _ = cancel.cancelled() => break,
1135            } {
1136                // read_chunk returned success
1137                Ok(Ok(chunk)) => chunk.unwrap_or(0),
1138                // read_chunk returned error
1139                Ok(Err(e)) => {
1140                    debug!("Received stream error: {:?}", e);
1141                    stats
1142                        .total_stream_read_errors
1143                        .fetch_add(1, Ordering::Relaxed);
1144                    break;
1145                }
1146                // timeout elapsed
1147                Err(_) => {
1148                    debug!("Timeout in receiving on stream");
1149                    stats
1150                        .total_stream_read_timeouts
1151                        .fetch_add(1, Ordering::Relaxed);
1152                    break;
1153                }
1154            };
1155
1156            match handle_chunks(
1157                // Bytes::clone() is a cheap atomic inc
1158                chunks.iter().take(n_chunks).cloned(),
1159                &mut accum,
1160                &packet_sender,
1161                &stats,
1162                peer_type,
1163            )
1164            .await
1165            {
1166                // The stream is finished, break out of the loop and close the stream.
1167                Ok(StreamState::Finished) => {
1168                    last_update.store(timing::timestamp(), Ordering::Relaxed);
1169                    break;
1170                }
1171                // The stream is still active, continue reading.
1172                Ok(StreamState::Receiving) => {}
1173                Err(_) => {
1174                    // Disconnect peers that send invalid streams.
1175                    connection.close(
1176                        CONNECTION_CLOSE_CODE_INVALID_STREAM.into(),
1177                        CONNECTION_CLOSE_REASON_INVALID_STREAM,
1178                    );
1179                    stats.total_streams.fetch_sub(1, Ordering::Relaxed);
1180                    stream_load_ema.update_ema_if_needed();
1181                    break 'conn;
1182                }
1183            }
1184        }
1185
1186        stats.total_streams.fetch_sub(1, Ordering::Relaxed);
1187        stream_load_ema.update_ema_if_needed();
1188    }
1189
1190    let stable_id = connection.stable_id();
1191    let removed_connection_count = connection_table.lock().await.remove_connection(
1192        ConnectionTableKey::new(remote_addr.ip(), remote_pubkey),
1193        remote_addr.port(),
1194        stable_id,
1195    );
1196    if removed_connection_count > 0 {
1197        stats
1198            .connection_removed
1199            .fetch_add(removed_connection_count, Ordering::Relaxed);
1200    } else {
1201        stats
1202            .connection_remove_failed
1203            .fetch_add(1, Ordering::Relaxed);
1204    }
1205    stats.total_connections.fetch_sub(1, Ordering::Relaxed);
1206}
1207
1208enum StreamState {
1209    // Stream is not finished, keep receiving chunks
1210    Receiving,
1211    // Stream is finished
1212    Finished,
1213}
1214
1215// Handle the chunks received from the stream. If the stream is finished, send the packet to the
1216// packet sender.
1217//
1218// Returns Err(()) if the stream is invalid.
1219async fn handle_chunks(
1220    chunks: impl ExactSizeIterator<Item = Bytes>,
1221    accum: &mut PacketAccumulator,
1222    packet_sender: &AsyncSender<PacketAccumulator>,
1223    stats: &StreamerStats,
1224    peer_type: ConnectionPeerType,
1225) -> Result<StreamState, ()> {
1226    let n_chunks = chunks.len();
1227    for chunk in chunks {
1228        accum.meta.size += chunk.len();
1229        if accum.meta.size > PACKET_DATA_SIZE {
1230            // The stream window size is set to PACKET_DATA_SIZE, so one individual chunk can
1231            // never exceed this size. A peer can send two chunks that together exceed the size
1232            // tho, in which case we report the error.
1233            stats.invalid_stream_size.fetch_add(1, Ordering::Relaxed);
1234            debug!("invalid stream size {}", accum.meta.size);
1235            return Err(());
1236        }
1237        accum.chunks.push(chunk);
1238        if peer_type.is_staked() {
1239            stats
1240                .total_staked_chunks_received
1241                .fetch_add(1, Ordering::Relaxed);
1242        } else {
1243            stats
1244                .total_unstaked_chunks_received
1245                .fetch_add(1, Ordering::Relaxed);
1246        }
1247    }
1248
1249    // n_chunks == 0 marks the end of a stream
1250    if n_chunks != 0 {
1251        return Ok(StreamState::Receiving);
1252    }
1253
1254    if accum.chunks.is_empty() {
1255        debug!("stream is empty");
1256        stats
1257            .total_packet_batches_none
1258            .fetch_add(1, Ordering::Relaxed);
1259        return Err(());
1260    }
1261
1262    // done receiving chunks
1263    let bytes_sent = accum.meta.size;
1264    let chunks_sent = accum.chunks.len();
1265
1266    if let Err(err) = packet_sender.send(accum.clone()).await {
1267        stats
1268            .total_handle_chunk_to_packet_batcher_send_err
1269            .fetch_add(1, Ordering::Relaxed);
1270        trace!("packet batch send error {:?}", err);
1271    } else {
1272        stats
1273            .total_packets_sent_for_batching
1274            .fetch_add(1, Ordering::Relaxed);
1275        stats
1276            .total_bytes_sent_for_batching
1277            .fetch_add(bytes_sent, Ordering::Relaxed);
1278        stats
1279            .total_chunks_sent_for_batching
1280            .fetch_add(chunks_sent, Ordering::Relaxed);
1281
1282        match peer_type {
1283            ConnectionPeerType::Unstaked => {
1284                stats
1285                    .total_unstaked_packets_sent_for_batching
1286                    .fetch_add(1, Ordering::Relaxed);
1287            }
1288            ConnectionPeerType::Staked(_) => {
1289                stats
1290                    .total_staked_packets_sent_for_batching
1291                    .fetch_add(1, Ordering::Relaxed);
1292            }
1293        }
1294
1295        trace!("sent {} byte packet for batching", bytes_sent);
1296    }
1297
1298    Ok(StreamState::Finished)
1299}
1300
1301#[derive(Debug)]
1302struct ConnectionEntry {
1303    cancel: CancellationToken,
1304    peer_type: ConnectionPeerType,
1305    last_update: Arc<AtomicU64>,
1306    port: u16,
1307    // We do not explicitly use it, but its drop is triggered when ConnectionEntry is dropped.
1308    _client_connection_tracker: ClientConnectionTracker,
1309    connection: Option<Connection>,
1310    stream_counter: Arc<ConnectionStreamCounter>,
1311}
1312
1313impl ConnectionEntry {
1314    fn new(
1315        cancel: CancellationToken,
1316        peer_type: ConnectionPeerType,
1317        last_update: Arc<AtomicU64>,
1318        port: u16,
1319        client_connection_tracker: ClientConnectionTracker,
1320        connection: Option<Connection>,
1321        stream_counter: Arc<ConnectionStreamCounter>,
1322    ) -> Self {
1323        Self {
1324            cancel,
1325            peer_type,
1326            last_update,
1327            port,
1328            _client_connection_tracker: client_connection_tracker,
1329            connection,
1330            stream_counter,
1331        }
1332    }
1333
1334    fn last_update(&self) -> u64 {
1335        self.last_update.load(Ordering::Relaxed)
1336    }
1337
1338    fn stake(&self) -> u64 {
1339        match self.peer_type {
1340            ConnectionPeerType::Unstaked => 0,
1341            ConnectionPeerType::Staked(stake) => stake,
1342        }
1343    }
1344}
1345
1346impl Drop for ConnectionEntry {
1347    fn drop(&mut self) {
1348        if let Some(conn) = self.connection.take() {
1349            conn.close(
1350                CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1351                CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1352            );
1353        }
1354        self.cancel.cancel();
1355    }
1356}
1357
1358#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
1359enum ConnectionTableKey {
1360    IP(IpAddr),
1361    Pubkey(Pubkey),
1362}
1363
1364impl ConnectionTableKey {
1365    fn new(ip: IpAddr, maybe_pubkey: Option<Pubkey>) -> Self {
1366        maybe_pubkey.map_or(ConnectionTableKey::IP(ip), |pubkey| {
1367            ConnectionTableKey::Pubkey(pubkey)
1368        })
1369    }
1370}
1371
1372// Map of IP to list of connection entries
1373struct ConnectionTable {
1374    table: IndexMap<ConnectionTableKey, Vec<ConnectionEntry>>,
1375    total_size: usize,
1376}
1377
1378// Prune the connection which has the oldest update
1379// Return number pruned
1380impl ConnectionTable {
1381    fn new() -> Self {
1382        Self {
1383            table: IndexMap::default(),
1384            total_size: 0,
1385        }
1386    }
1387
1388    fn prune_oldest(&mut self, max_size: usize) -> usize {
1389        let mut num_pruned = 0;
1390        let key = |(_, connections): &(_, &Vec<_>)| {
1391            connections.iter().map(ConnectionEntry::last_update).min()
1392        };
1393        while self.total_size.saturating_sub(num_pruned) > max_size {
1394            match self.table.values().enumerate().min_by_key(key) {
1395                None => break,
1396                Some((index, connections)) => {
1397                    num_pruned += connections.len();
1398                    self.table.swap_remove_index(index);
1399                }
1400            }
1401        }
1402        self.total_size = self.total_size.saturating_sub(num_pruned);
1403        num_pruned
1404    }
1405
    // Randomly samples sample_size connections, evicts the one with the
    // lowest stake, and returns the number of pruned connections.
    // If the stakes of all the sampled connections are at or above
    // threshold_stake, the pruning attempt is rejected and 0 is returned.
1410    fn prune_random(&mut self, sample_size: usize, threshold_stake: u64) -> usize {
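        // Build a stream of candidate indices: `once(len)` + `filter` guard
        // against an empty table, then `flat_map`/`repeat_with` yield an
        // endless supply of random indices in 0..len, of which
        // `take(sample_size)` below are actually inspected.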
1411        let num_pruned = std::iter::once(self.table.len())
1412            .filter(|&size| size > 0)
1413            .flat_map(|size| {
1414                let mut rng = thread_rng();
1415                repeat_with(move || rng.gen_range(0..size))
1416            })
1417            .map(|index| {
1418                let connection = self.table[index].first();
1419                let stake = connection.map(|connection: &ConnectionEntry| connection.stake());
1420                (index, stake)
1421            })
1422            .take(sample_size)
1423            .min_by_key(|&(_, stake)| stake)
1424            .filter(|&(_, stake)| stake < Some(threshold_stake))
1425            .and_then(|(index, _)| self.table.swap_remove_index(index))
1426            .map(|(_, connections)| connections.len())
1427            .unwrap_or_default();
1428        self.total_size = self.total_size.saturating_sub(num_pruned);
1429        num_pruned
1430    }
1431
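    // Attempt to add a connection entry under `key`. If the peer already holds
    // max_connections_per_peer entries, the connection (if provided) is closed
    // with CONNECTION_CLOSE_CODE_TOO_MANY and None is returned; otherwise the
    // entry is added and its shared last_update, cancellation token and stream
    // counter are returned.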
1432    fn try_add_connection(
1433        &mut self,
1434        key: ConnectionTableKey,
1435        port: u16,
1436        client_connection_tracker: ClientConnectionTracker,
1437        connection: Option<Connection>,
1438        peer_type: ConnectionPeerType,
1439        last_update: u64,
1440        max_connections_per_peer: usize,
1441    ) -> Option<(
1442        Arc<AtomicU64>,
1443        CancellationToken,
1444        Arc<ConnectionStreamCounter>,
1445    )> {
1446        let connection_entry = self.table.entry(key).or_default();
1447        let has_connection_capacity = connection_entry
1448            .len()
1449            .checked_add(1)
1450            .map(|c| c <= max_connections_per_peer)
1451            .unwrap_or(false);
1452        if has_connection_capacity {
1453            let cancel = CancellationToken::new();
1454            let last_update = Arc::new(AtomicU64::new(last_update));
1455            let stream_counter = connection_entry
1456                .first()
1457                .map(|entry| entry.stream_counter.clone())
                .unwrap_or_else(|| Arc::new(ConnectionStreamCounter::new()));
1459            connection_entry.push(ConnectionEntry::new(
1460                cancel.clone(),
1461                peer_type,
1462                last_update.clone(),
1463                port,
1464                client_connection_tracker,
1465                connection,
1466                stream_counter.clone(),
1467            ));
1468            self.total_size += 1;
1469            Some((last_update, cancel, stream_counter))
1470        } else {
1471            if let Some(connection) = connection {
1472                connection.close(
1473                    CONNECTION_CLOSE_CODE_TOO_MANY.into(),
1474                    CONNECTION_CLOSE_REASON_TOO_MANY,
1475                );
1476            }
1477            None
1478        }
1479    }
1480
1481    // Returns number of connections that were removed
1482    fn remove_connection(&mut self, key: ConnectionTableKey, port: u16, stable_id: usize) -> usize {
1483        if let Entry::Occupied(mut e) = self.table.entry(key) {
1484            let e_ref = e.get_mut();
1485            let old_size = e_ref.len();
1486
1487            e_ref.retain(|connection_entry| {
                // Retain the entry if the port differs, or if the connection's stable_id
                // does not match the provided stable_id.
                // (Some unit tests do not store a real connection in the table; when the
                // connection is None the stable_id check is skipped, so a matching port
                // alone removes the entry.)
                connection_entry.port != port
                    || connection_entry
                        .connection
                        .as_ref()
                        .is_some_and(|connection| connection.stable_id() != stable_id)
1499            });
1500            let new_size = e_ref.len();
1501            if e_ref.is_empty() {
1502                e.swap_remove_entry();
1503            }
1504            let connections_removed = old_size.saturating_sub(new_size);
1505            self.total_size = self.total_size.saturating_sub(connections_removed);
1506            connections_removed
1507        } else {
1508            0
1509        }
1510    }
1511}
1512
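// Future adapter that polls a single endpoint's `accept()` and tags the result
// with the endpoint's index, so accepts from multiple endpoints can be awaited
// together (e.g. in a FuturesUnordered).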
1513struct EndpointAccept<'a> {
1514    endpoint: usize,
1515    accept: Accept<'a>,
1516}
1517
1518impl Future for EndpointAccept<'_> {
1519    type Output = (Option<quinn::Incoming>, usize);
1520
1521    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
1522        let i = self.endpoint;
1523        // Safety:
1524        // self is pinned and accept is a field so it can't get moved out. See safety docs of
1525        // map_unchecked_mut.
1526        unsafe { self.map_unchecked_mut(|this| &mut this.accept) }
1527            .poll(cx)
1528            .map(|r| (r, i))
1529    }
1530}
1531
1532#[cfg(test)]
1533pub mod test {
1534    use {
1535        super::*,
1536        crate::{
1537            nonblocking::{
1538                quic::compute_max_allowed_uni_streams,
1539                testing_utilities::{
1540                    check_multiple_streams, get_client_config, make_client_endpoint,
1541                    setup_quic_server, SpawnTestServerResult, TestServerConfig,
1542                },
1543            },
1544            quic::DEFAULT_TPU_COALESCE,
1545        },
1546        assert_matches::assert_matches,
1547        async_channel::unbounded as async_unbounded,
1548        crossbeam_channel::{unbounded, Receiver},
1549        quinn::{ApplicationClose, ConnectionError},
1550        solana_keypair::Keypair,
1551        solana_net_utils::bind_to_localhost,
1552        solana_signer::Signer,
1553        std::collections::HashMap,
1554        tokio::time::sleep,
1555    };
1556
1557    pub async fn check_timeout(receiver: Receiver<PacketBatch>, server_address: SocketAddr) {
1558        let conn1 = make_client_endpoint(&server_address, None).await;
1559        let total = 30;
1560        for i in 0..total {
1561            let mut s1 = conn1.open_uni().await.unwrap();
1562            s1.write_all(&[0u8]).await.unwrap();
1563            s1.finish().unwrap();
1564            info!("done {}", i);
1565            sleep(Duration::from_millis(1000)).await;
1566        }
1567        let mut received = 0;
1568        loop {
1569            if let Ok(_x) = receiver.try_recv() {
1570                received += 1;
1571                info!("got {}", received);
1572            } else {
1573                sleep(Duration::from_millis(500)).await;
1574            }
1575            if received >= total {
1576                break;
1577            }
1578        }
1579    }
1580
1581    pub async fn check_block_multiple_connections(server_address: SocketAddr) {
1582        let conn1 = make_client_endpoint(&server_address, None).await;
1583        let conn2 = make_client_endpoint(&server_address, None).await;
1584        let mut s1 = conn1.open_uni().await.unwrap();
1585        let s2 = conn2.open_uni().await;
1586        if let Ok(mut s2) = s2 {
1587            s1.write_all(&[0u8]).await.unwrap();
1588            s1.finish().unwrap();
            // Send enough data to create more than one chunk.
            // The first chunk will try to open the connection (which should fail).
            // The following chunks allow the connection failure to be detected.
1592            let data = vec![1u8; PACKET_DATA_SIZE * 2];
1593            s2.write_all(&data)
1594                .await
1595                .expect_err("shouldn't be able to open 2 connections");
1596        } else {
            // It has been observed that, if there is already a connection open against the
            // server, this open_uni can fail with an ApplicationClosed(ApplicationClose) error
            // due to CONNECTION_CLOSE_CODE_TOO_MANY before writing to the stream -- expect it.
1600            assert_matches!(s2, Err(quinn::ConnectionError::ApplicationClosed(_)));
1601        }
1602    }
1603
1604    pub async fn check_multiple_writes(
1605        receiver: Receiver<PacketBatch>,
1606        server_address: SocketAddr,
1607        client_keypair: Option<&Keypair>,
1608    ) {
1609        let conn1 = Arc::new(make_client_endpoint(&server_address, client_keypair).await);
1610
1611        // Send a full size packet with single byte writes.
1612        let num_bytes = PACKET_DATA_SIZE;
1613        let num_expected_packets = 1;
1614        let mut s1 = conn1.open_uni().await.unwrap();
1615        for _ in 0..num_bytes {
1616            s1.write_all(&[0u8]).await.unwrap();
1617        }
1618        s1.finish().unwrap();
1619
1620        let mut all_packets = vec![];
1621        let now = Instant::now();
1622        let mut total_packets = 0;
1623        while now.elapsed().as_secs() < 5 {
            // We're running in an async environment; we (almost) never
            // want to block.
1626            if let Ok(packets) = receiver.try_recv() {
1627                total_packets += packets.len();
1628                all_packets.push(packets)
1629            } else {
1630                sleep(Duration::from_secs(1)).await;
1631            }
1632            if total_packets >= num_expected_packets {
1633                break;
1634            }
1635        }
1636        for batch in all_packets {
1637            for p in batch.iter() {
1638                assert_eq!(p.meta().size, num_bytes);
1639            }
1640        }
1641        assert_eq!(total_packets, num_expected_packets);
1642    }
1643
1644    pub async fn check_unstaked_node_connect_failure(server_address: SocketAddr) {
1645        let conn1 = Arc::new(make_client_endpoint(&server_address, None).await);
1646
1647        // Send a full size packet with single byte writes.
1648        if let Ok(mut s1) = conn1.open_uni().await {
1649            for _ in 0..PACKET_DATA_SIZE {
1650                // Ignoring any errors here. s1.finish() will test the error condition
1651                s1.write_all(&[0u8]).await.unwrap_or_default();
1652            }
1653            s1.finish().unwrap_or_default();
1654            s1.stopped().await.unwrap_err();
1655        }
1656    }
1657
1658    #[tokio::test(flavor = "multi_thread")]
1659    async fn test_quic_server_exit() {
1660        let SpawnTestServerResult {
1661            join_handle,
1662            exit,
1663            receiver: _,
1664            server_address: _,
1665            stats: _,
1666        } = setup_quic_server(None, TestServerConfig::default());
1667        exit.store(true, Ordering::Relaxed);
1668        join_handle.await.unwrap();
1669    }
1670
1671    #[tokio::test(flavor = "multi_thread")]
1672    async fn test_quic_timeout() {
1673        solana_logger::setup();
1674        let SpawnTestServerResult {
1675            join_handle,
1676            exit,
1677            receiver,
1678            server_address,
1679            stats: _,
1680        } = setup_quic_server(None, TestServerConfig::default());
1681
1682        check_timeout(receiver, server_address).await;
1683        exit.store(true, Ordering::Relaxed);
1684        join_handle.await.unwrap();
1685    }
1686
1687    #[tokio::test(flavor = "multi_thread")]
1688    async fn test_packet_batcher() {
1689        solana_logger::setup();
1690        let (pkt_batch_sender, pkt_batch_receiver) = unbounded();
        let (pkt_sender, pkt_receiver) = async_unbounded();
1692        let exit = Arc::new(AtomicBool::new(false));
1693        let stats = Arc::new(StreamerStats::default());
1694
1695        let handle = tokio::spawn(packet_batch_sender(
1696            pkt_batch_sender,
1697            pkt_receiver,
1698            exit.clone(),
1699            stats,
1700            DEFAULT_TPU_COALESCE,
1701        ));
1702
1703        let num_packets = 1000;
1704
1705        for _i in 0..num_packets {
1706            let mut meta = Meta::default();
1707            let bytes = Bytes::from("Hello world");
1708            let size = bytes.len();
1709            meta.size = size;
1710            let packet_accum = PacketAccumulator {
1711                meta,
1712                chunks: smallvec::smallvec![bytes],
1713                start_time: Instant::now(),
1714            };
            pkt_sender.send(packet_accum).await.unwrap();
1716        }
1717        let mut i = 0;
1718        let start = Instant::now();
1719        while i < num_packets && start.elapsed().as_secs() < 2 {
1720            if let Ok(batch) = pkt_batch_receiver.try_recv() {
1721                i += batch.len();
1722            } else {
1723                sleep(Duration::from_millis(1)).await;
1724            }
1725        }
1726        assert_eq!(i, num_packets);
1727        exit.store(true, Ordering::Relaxed);
1728        // Explicit drop to wake up packet_batch_sender
        drop(pkt_sender);
1730        handle.await.unwrap();
1731    }
1732
1733    #[tokio::test(flavor = "multi_thread")]
1734    async fn test_quic_stream_timeout() {
1735        solana_logger::setup();
1736        let SpawnTestServerResult {
1737            join_handle,
1738            exit,
1739            receiver: _,
1740            server_address,
1741            stats,
1742        } = setup_quic_server(None, TestServerConfig::default());
1743
1744        let conn1 = make_client_endpoint(&server_address, None).await;
1745        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
1746        assert_eq!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
1747
1748        // Send one byte to start the stream
1749        let mut s1 = conn1.open_uni().await.unwrap();
1750        s1.write_all(&[0u8]).await.unwrap_or_default();
1751
1752        // Wait long enough for the stream to timeout in receiving chunks
1753        let sleep_time = DEFAULT_WAIT_FOR_CHUNK_TIMEOUT * 2;
1754        sleep(sleep_time).await;
1755
1756        // Test that the stream was created, but timed out in read
1757        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
1758        assert_ne!(stats.total_stream_read_timeouts.load(Ordering::Relaxed), 0);
1759
1760        // Test that more writes to the stream will fail (i.e. the stream is no longer writable
1761        // after the timeouts)
1762        assert!(s1.write_all(&[0u8]).await.is_err());
1763
1764        exit.store(true, Ordering::Relaxed);
1765        join_handle.await.unwrap();
1766    }
1767
1768    #[tokio::test(flavor = "multi_thread")]
1769    async fn test_quic_server_block_multiple_connections() {
1770        solana_logger::setup();
1771        let SpawnTestServerResult {
1772            join_handle,
1773            exit,
1774            receiver: _,
1775            server_address,
1776            stats: _,
1777        } = setup_quic_server(None, TestServerConfig::default());
1778        check_block_multiple_connections(server_address).await;
1779        exit.store(true, Ordering::Relaxed);
1780        join_handle.await.unwrap();
1781    }
1782
1783    #[tokio::test(flavor = "multi_thread")]
1784    async fn test_quic_server_multiple_connections_on_single_client_endpoint() {
1785        solana_logger::setup();
1786
1787        let SpawnTestServerResult {
1788            join_handle,
1789            exit,
1790            receiver: _,
1791            server_address,
1792            stats,
1793        } = setup_quic_server(
1794            None,
1795            TestServerConfig {
1796                max_connections_per_peer: 2,
1797                ..Default::default()
1798            },
1799        );
1800
1801        let client_socket = bind_to_localhost().unwrap();
1802        let mut endpoint = quinn::Endpoint::new(
1803            EndpointConfig::default(),
1804            None,
1805            client_socket,
1806            Arc::new(TokioRuntime),
1807        )
1808        .unwrap();
1809        let default_keypair = Keypair::new();
1810        endpoint.set_default_client_config(get_client_config(&default_keypair));
1811        let conn1 = endpoint
1812            .connect(server_address, "localhost")
1813            .expect("Failed in connecting")
1814            .await
1815            .expect("Failed in waiting");
1816
1817        let conn2 = endpoint
1818            .connect(server_address, "localhost")
1819            .expect("Failed in connecting")
1820            .await
1821            .expect("Failed in waiting");
1822
1823        let mut s1 = conn1.open_uni().await.unwrap();
1824        s1.write_all(&[0u8]).await.unwrap();
1825        s1.finish().unwrap();
1826
1827        let mut s2 = conn2.open_uni().await.unwrap();
1828        conn1.close(
1829            CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1830            CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1831        );
1832
1833        let start = Instant::now();
1834        while stats.connection_removed.load(Ordering::Relaxed) != 1 {
1835            debug!("First connection not removed yet");
1836            sleep(Duration::from_millis(10)).await;
1837        }
1838        assert!(start.elapsed().as_secs() < 1);
1839
1840        s2.write_all(&[0u8]).await.unwrap();
1841        s2.finish().unwrap();
1842
1843        conn2.close(
1844            CONNECTION_CLOSE_CODE_DROPPED_ENTRY.into(),
1845            CONNECTION_CLOSE_REASON_DROPPED_ENTRY,
1846        );
1847
1848        let start = Instant::now();
1849        while stats.connection_removed.load(Ordering::Relaxed) != 2 {
1850            debug!("Second connection not removed yet");
1851            sleep(Duration::from_millis(10)).await;
1852        }
1853        assert!(start.elapsed().as_secs() < 1);
1854
1855        exit.store(true, Ordering::Relaxed);
1856        join_handle.await.unwrap();
1857    }
1858
1859    #[tokio::test(flavor = "multi_thread")]
1860    async fn test_quic_server_multiple_writes() {
1861        solana_logger::setup();
1862        let SpawnTestServerResult {
1863            join_handle,
1864            exit,
1865            receiver,
1866            server_address,
1867            stats: _,
1868        } = setup_quic_server(None, TestServerConfig::default());
1869        check_multiple_writes(receiver, server_address, None).await;
1870        exit.store(true, Ordering::Relaxed);
1871        join_handle.await.unwrap();
1872    }
1873
1874    #[tokio::test(flavor = "multi_thread")]
1875    async fn test_quic_server_staked_connection_removal() {
1876        solana_logger::setup();
1877
1878        let client_keypair = Keypair::new();
1879        let stakes = HashMap::from([(client_keypair.pubkey(), 100_000)]);
1880        let staked_nodes = StakedNodes::new(
1881            Arc::new(stakes),
1882            HashMap::<Pubkey, u64>::default(), // overrides
1883        );
1884        let SpawnTestServerResult {
1885            join_handle,
1886            exit,
1887            receiver,
1888            server_address,
1889            stats,
1890        } = setup_quic_server(Some(staked_nodes), TestServerConfig::default());
1891        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
1892        exit.store(true, Ordering::Relaxed);
1893        join_handle.await.unwrap();
1894        sleep(Duration::from_millis(100)).await;
1895        assert_eq!(
1896            stats
1897                .connection_added_from_unstaked_peer
1898                .load(Ordering::Relaxed),
1899            0
1900        );
1901        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1902        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1903    }
1904
1905    #[tokio::test(flavor = "multi_thread")]
1906    async fn test_quic_server_zero_staked_connection_removal() {
1907        // In this test, the client has a pubkey, but is not in stake table.
1908        solana_logger::setup();
1909
1910        let client_keypair = Keypair::new();
1911        let stakes = HashMap::from([(client_keypair.pubkey(), 0)]);
1912        let staked_nodes = StakedNodes::new(
1913            Arc::new(stakes),
1914            HashMap::<Pubkey, u64>::default(), // overrides
1915        );
1916        let SpawnTestServerResult {
1917            join_handle,
1918            exit,
1919            receiver,
1920            server_address,
1921            stats,
1922        } = setup_quic_server(Some(staked_nodes), TestServerConfig::default());
1923        check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
1924        exit.store(true, Ordering::Relaxed);
1925        join_handle.await.unwrap();
1926        sleep(Duration::from_millis(100)).await;
1927        assert_eq!(
1928            stats
1929                .connection_added_from_staked_peer
1930                .load(Ordering::Relaxed),
1931            0
1932        );
1933        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1934        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1935    }
1936
1937    #[tokio::test(flavor = "multi_thread")]
1938    async fn test_quic_server_unstaked_connection_removal() {
1939        solana_logger::setup();
1940        let SpawnTestServerResult {
1941            join_handle,
1942            exit,
1943            receiver,
1944            server_address,
1945            stats,
1946        } = setup_quic_server(None, TestServerConfig::default());
1947        check_multiple_writes(receiver, server_address, None).await;
1948        exit.store(true, Ordering::Relaxed);
1949        join_handle.await.unwrap();
1950        sleep(Duration::from_millis(100)).await;
1951        assert_eq!(
1952            stats
1953                .connection_added_from_staked_peer
1954                .load(Ordering::Relaxed),
1955            0
1956        );
1957        assert_eq!(stats.connection_removed.load(Ordering::Relaxed), 1);
1958        assert_eq!(stats.connection_remove_failed.load(Ordering::Relaxed), 0);
1959    }
1960
1961    #[tokio::test(flavor = "multi_thread")]
1962    async fn test_quic_server_unstaked_node_connect_failure() {
1963        solana_logger::setup();
1964        let s = bind_to_localhost().unwrap();
1965        let exit = Arc::new(AtomicBool::new(false));
1966        let (sender, _) = unbounded();
1967        let keypair = Keypair::new();
1968        let server_address = s.local_addr().unwrap();
1969        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
1970        let SpawnNonBlockingServerResult {
1971            endpoints: _,
1972            stats: _,
1973            thread: t,
1974            max_concurrent_connections: _,
1975        } = spawn_server(
1976            "quic_streamer_test",
1977            s,
1978            &keypair,
1979            sender,
1980            exit.clone(),
1981            staked_nodes,
1982            QuicServerParams {
1983                max_unstaked_connections: 0, // Do not allow any connection from unstaked clients/nodes
1984                coalesce_channel_size: 100_000, // smaller channel size for faster test
1985                ..QuicServerParams::default()
1986            },
1987        )
1988        .unwrap();
1989
1990        check_unstaked_node_connect_failure(server_address).await;
1991        exit.store(true, Ordering::Relaxed);
1992        t.await.unwrap();
1993    }
1994
1995    #[tokio::test(flavor = "multi_thread")]
1996    async fn test_quic_server_multiple_streams() {
1997        solana_logger::setup();
1998        let s = bind_to_localhost().unwrap();
1999        let exit = Arc::new(AtomicBool::new(false));
2000        let (sender, receiver) = unbounded();
2001        let keypair = Keypair::new();
2002        let server_address = s.local_addr().unwrap();
2003        let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
2004        let SpawnNonBlockingServerResult {
2005            endpoints: _,
2006            stats,
2007            thread: t,
2008            max_concurrent_connections: _,
2009        } = spawn_server(
2010            "quic_streamer_test",
2011            s,
2012            &keypair,
2013            sender,
2014            exit.clone(),
2015            staked_nodes,
2016            QuicServerParams {
2017                max_connections_per_peer: 2,
2018                coalesce_channel_size: 100_000, // smaller channel size for faster test
2019                ..QuicServerParams::default()
2020            },
2021        )
2022        .unwrap();
2023
2024        check_multiple_streams(receiver, server_address, None).await;
2025        assert_eq!(stats.total_streams.load(Ordering::Relaxed), 0);
2026        assert_eq!(stats.total_new_streams.load(Ordering::Relaxed), 20);
2027        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 2);
2028        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
2029        exit.store(true, Ordering::Relaxed);
2030        t.await.unwrap();
2031        assert_eq!(stats.total_connections.load(Ordering::Relaxed), 0);
2032        assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
2033    }
2034
2035    #[test]
2036    fn test_prune_table_with_ip() {
2037        use std::net::Ipv4Addr;
2038        solana_logger::setup();
2039        let mut table = ConnectionTable::new();
2040        let mut num_entries = 5;
2041        let max_connections_per_peer = 10;
2042        let sockets: Vec<_> = (0..num_entries)
2043            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2044            .collect();
2045        let stats = Arc::new(StreamerStats::default());
2046        for (i, socket) in sockets.iter().enumerate() {
2047            table
2048                .try_add_connection(
2049                    ConnectionTableKey::IP(socket.ip()),
2050                    socket.port(),
2051                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2052                    None,
2053                    ConnectionPeerType::Unstaked,
2054                    i as u64,
2055                    max_connections_per_peer,
2056                )
2057                .unwrap();
2058        }
2059        num_entries += 1;
2060        table
2061            .try_add_connection(
2062                ConnectionTableKey::IP(sockets[0].ip()),
2063                sockets[0].port(),
2064                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2065                None,
2066                ConnectionPeerType::Unstaked,
2067                5,
2068                max_connections_per_peer,
2069            )
2070            .unwrap();
2071
2072        let new_size = 3;
2073        let pruned = table.prune_oldest(new_size);
2074        assert_eq!(pruned, num_entries as usize - new_size);
2075        for v in table.table.values() {
2076            for x in v {
2077                assert!((x.last_update() + 1) >= (num_entries as u64 - new_size as u64));
2078            }
2079        }
2080        assert_eq!(table.table.len(), new_size);
2081        assert_eq!(table.total_size, new_size);
2082        for socket in sockets.iter().take(num_entries as usize).skip(new_size - 1) {
2083            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
2084        }
2085        assert_eq!(table.total_size, 0);
2086        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2087    }
2088
2089    #[test]
2090    fn test_prune_table_with_unique_pubkeys() {
2091        solana_logger::setup();
2092        let mut table = ConnectionTable::new();
2093
2094        // We should be able to add more entries than max_connections_per_peer, since each entry is
2095        // from a different peer pubkey.
2096        let num_entries = 15;
2097        let max_connections_per_peer = 10;
2098        let stats = Arc::new(StreamerStats::default());
2099
2100        let pubkeys: Vec<_> = (0..num_entries).map(|_| Pubkey::new_unique()).collect();
2101        for (i, pubkey) in pubkeys.iter().enumerate() {
2102            table
2103                .try_add_connection(
2104                    ConnectionTableKey::Pubkey(*pubkey),
2105                    0,
2106                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2107                    None,
2108                    ConnectionPeerType::Unstaked,
2109                    i as u64,
2110                    max_connections_per_peer,
2111                )
2112                .unwrap();
2113        }
2114
2115        let new_size = 3;
2116        let pruned = table.prune_oldest(new_size);
2117        assert_eq!(pruned, num_entries as usize - new_size);
2118        assert_eq!(table.table.len(), new_size);
2119        assert_eq!(table.total_size, new_size);
2120        for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) {
2121            table.remove_connection(ConnectionTableKey::Pubkey(*pubkey), 0, 0);
2122        }
2123        assert_eq!(table.total_size, 0);
2124        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2125    }
2126
2127    #[test]
2128    fn test_prune_table_with_non_unique_pubkeys() {
2129        solana_logger::setup();
2130        let mut table = ConnectionTable::new();
2131
2132        let max_connections_per_peer = 10;
2133        let pubkey = Pubkey::new_unique();
2134        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2135
2136        (0..max_connections_per_peer).for_each(|i| {
2137            table
2138                .try_add_connection(
2139                    ConnectionTableKey::Pubkey(pubkey),
2140                    0,
2141                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2142                    None,
2143                    ConnectionPeerType::Unstaked,
2144                    i as u64,
2145                    max_connections_per_peer,
2146                )
2147                .unwrap();
2148        });
2149
2150        // We should NOT be able to add more entries than max_connections_per_peer, since we are
2151        // using the same peer pubkey.
2152        assert!(table
2153            .try_add_connection(
2154                ConnectionTableKey::Pubkey(pubkey),
2155                0,
2156                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2157                None,
2158                ConnectionPeerType::Unstaked,
2159                10,
2160                max_connections_per_peer,
2161            )
2162            .is_none());
2163
2164        // We should be able to add an entry from another peer pubkey
2165        let num_entries = max_connections_per_peer + 1;
2166        let pubkey2 = Pubkey::new_unique();
2167        assert!(table
2168            .try_add_connection(
2169                ConnectionTableKey::Pubkey(pubkey2),
2170                0,
2171                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2172                None,
2173                ConnectionPeerType::Unstaked,
2174                10,
2175                max_connections_per_peer,
2176            )
2177            .is_some());
2178
2179        assert_eq!(table.total_size, num_entries);
2180
2181        let new_max_size = 3;
2182        let pruned = table.prune_oldest(new_max_size);
2183        assert!(pruned >= num_entries - new_max_size);
2184        assert!(table.table.len() <= new_max_size);
2185        assert!(table.total_size <= new_max_size);
2186
2187        table.remove_connection(ConnectionTableKey::Pubkey(pubkey2), 0, 0);
2188        assert_eq!(table.total_size, 0);
2189        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2190    }
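
    // Minimal sketch: ConnectionTableKey::new prefers the pubkey over the IP
    // address whenever a pubkey is supplied.
    #[test]
    fn test_connection_table_key_prefers_pubkey() {
        use std::net::Ipv4Addr;
        let ip = IpAddr::V4(Ipv4Addr::LOCALHOST);
        let pubkey = Pubkey::new_unique();
        assert_eq!(ConnectionTableKey::new(ip, None), ConnectionTableKey::IP(ip));
        assert_eq!(
            ConnectionTableKey::new(ip, Some(pubkey)),
            ConnectionTableKey::Pubkey(pubkey)
        );
    }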
2191
2192    #[test]
2193    fn test_prune_table_random() {
2194        use std::net::Ipv4Addr;
2195        solana_logger::setup();
2196        let mut table = ConnectionTable::new();
2197        let num_entries = 5;
2198        let max_connections_per_peer = 10;
2199        let sockets: Vec<_> = (0..num_entries)
2200            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2201            .collect();
2202        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2203
2204        for (i, socket) in sockets.iter().enumerate() {
2205            table
2206                .try_add_connection(
2207                    ConnectionTableKey::IP(socket.ip()),
2208                    socket.port(),
2209                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2210                    None,
2211                    ConnectionPeerType::Staked((i + 1) as u64),
2212                    i as u64,
2213                    max_connections_per_peer,
2214                )
2215                .unwrap();
2216        }
2217
        // Try pruning with a threshold stake lower than the stake of every entry in the table.
        // It should fail to prune (i.e. return 0 pruned entries).
2220        let pruned = table.prune_random(/*sample_size:*/ 2, /*threshold_stake:*/ 0);
2221        assert_eq!(pruned, 0);
2222
        // Try pruning with a threshold stake higher than the stake of every entry in the table.
        // It should succeed in pruning (i.e. return 1 pruned entry).
2225        let pruned = table.prune_random(
2226            2,                      // sample_size
2227            num_entries as u64 + 1, // threshold_stake
2228        );
2229        assert_eq!(pruned, 1);
        // We had 5 connections and pruned 1, so we should have 4 left
2231        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 4);
2232    }
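
    // Minimal sketch of the prune_random threshold boundary: with a single
    // staked entry, a threshold equal to the stake rejects pruning, while a
    // strictly larger threshold evicts the entry.
    #[test]
    fn test_prune_random_threshold_boundary() {
        use std::net::Ipv4Addr;
        let mut table = ConnectionTable::new();
        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
        table
            .try_add_connection(
                ConnectionTableKey::IP(IpAddr::V4(Ipv4Addr::LOCALHOST)),
                0,
                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
                None,
                ConnectionPeerType::Staked(5),
                0,
                10, // max_connections_per_peer
            )
            .unwrap();
        // All sampled stakes are at or above the threshold: pruning is rejected.
        assert_eq!(table.prune_random(/*sample_size:*/ 1, /*threshold_stake:*/ 5), 0);
        // Threshold above the stake: the single entry is evicted.
        assert_eq!(table.prune_random(/*sample_size:*/ 1, /*threshold_stake:*/ 6), 1);
        assert_eq!(table.total_size, 0);
    }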
2233
2234    #[test]
2235    fn test_remove_connections() {
2236        use std::net::Ipv4Addr;
2237        solana_logger::setup();
2238        let mut table = ConnectionTable::new();
2239        let num_ips = 5;
2240        let max_connections_per_peer = 10;
2241        let mut sockets: Vec<_> = (0..num_ips)
2242            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::new(i, 0, 0, 0)), 0))
2243            .collect();
2244        let stats: Arc<StreamerStats> = Arc::new(StreamerStats::default());
2245
2246        for (i, socket) in sockets.iter().enumerate() {
2247            table
2248                .try_add_connection(
2249                    ConnectionTableKey::IP(socket.ip()),
2250                    socket.port(),
2251                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2252                    None,
2253                    ConnectionPeerType::Unstaked,
2254                    (i * 2) as u64,
2255                    max_connections_per_peer,
2256                )
2257                .unwrap();
2258
2259            table
2260                .try_add_connection(
2261                    ConnectionTableKey::IP(socket.ip()),
2262                    socket.port(),
2263                    ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2264                    None,
2265                    ConnectionPeerType::Unstaked,
2266                    (i * 2 + 1) as u64,
2267                    max_connections_per_peer,
2268                )
2269                .unwrap();
2270        }
2271
2272        let single_connection_addr =
2273            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips, 0, 0, 0)), 0);
2274        table
2275            .try_add_connection(
2276                ConnectionTableKey::IP(single_connection_addr.ip()),
2277                single_connection_addr.port(),
2278                ClientConnectionTracker::new(stats.clone(), 1000).unwrap(),
2279                None,
2280                ConnectionPeerType::Unstaked,
2281                (num_ips * 2) as u64,
2282                max_connections_per_peer,
2283            )
2284            .unwrap();
2285
2286        let zero_connection_addr =
2287            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(num_ips + 1, 0, 0, 0)), 0);
2288
2289        sockets.push(single_connection_addr);
2290        sockets.push(zero_connection_addr);
2291
2292        for socket in sockets.iter() {
2293            table.remove_connection(ConnectionTableKey::IP(socket.ip()), socket.port(), 0);
2294        }
2295        assert_eq!(table.total_size, 0);
2296        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2297    }
2298
    #[test]
    fn test_max_allowed_uni_streams() {
2302        assert_eq!(
2303            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 0),
2304            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
2305        );
2306        assert_eq!(
2307            compute_max_allowed_uni_streams(ConnectionPeerType::Staked(10), 0),
2308            QUIC_MIN_STAKED_CONCURRENT_STREAMS
2309        );
2310        let delta =
2311            (QUIC_TOTAL_STAKED_CONCURRENT_STREAMS - QUIC_MIN_STAKED_CONCURRENT_STREAMS) as f64;
2312        assert_eq!(
2313            compute_max_allowed_uni_streams(ConnectionPeerType::Staked(1000), 10000),
2314            QUIC_MAX_STAKED_CONCURRENT_STREAMS,
2315        );
2316        assert_eq!(
2317            compute_max_allowed_uni_streams(ConnectionPeerType::Staked(100), 10000),
2318            ((delta / (100_f64)) as usize + QUIC_MIN_STAKED_CONCURRENT_STREAMS)
2319                .min(QUIC_MAX_STAKED_CONCURRENT_STREAMS)
2320        );
2321        assert_eq!(
2322            compute_max_allowed_uni_streams(ConnectionPeerType::Unstaked, 10000),
2323            QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS
2324        );
2325    }
2326
2327    #[test]
    fn test_calculate_receive_window_ratio_for_staked_node() {
2329        let mut max_stake = 10000;
2330        let mut min_stake = 0;
2331        let ratio = compute_receive_window_ratio_for_staked_node(max_stake, min_stake, min_stake);
2332        assert_eq!(ratio, QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO);
2333
2334        let ratio = compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake);
2335        let max_ratio = QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO;
2336        assert_eq!(ratio, max_ratio);
2337
2338        let ratio =
2339            compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake / 2);
2340        let average_ratio =
2341            (QUIC_MAX_STAKED_RECEIVE_WINDOW_RATIO + QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO) / 2;
2342        assert_eq!(ratio, average_ratio);
2343
2344        max_stake = 10000;
2345        min_stake = 10000;
2346        let ratio = compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake);
2347        assert_eq!(ratio, max_ratio);
2348
2349        max_stake = 0;
2350        min_stake = 0;
2351        let ratio = compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake);
2352        assert_eq!(ratio, max_ratio);
2353
2354        max_stake = 1000;
2355        min_stake = 10;
2356        let ratio =
2357            compute_receive_window_ratio_for_staked_node(max_stake, min_stake, max_stake + 10);
2358        assert_eq!(ratio, max_ratio);
2359    }
2360
2361    #[tokio::test(flavor = "multi_thread")]
2362    async fn test_throttling_check_no_packet_drop() {
2363        solana_logger::setup_with_default_filter();
2364
2365        let SpawnTestServerResult {
2366            join_handle,
2367            exit,
2368            receiver,
2369            server_address,
2370            stats,
2371        } = setup_quic_server(None, TestServerConfig::default());
2372
2373        let client_connection = make_client_endpoint(&server_address, None).await;
2374
        // An unstaked connection can handle up to 100 TPS, so sending all of them should take about 1s.
2376        let expected_num_txs = 100;
2377        let start_time = tokio::time::Instant::now();
2378        for i in 0..expected_num_txs {
2379            let mut send_stream = client_connection.open_uni().await.unwrap();
2380            let data = format!("{i}").into_bytes();
2381            send_stream.write_all(&data).await.unwrap();
2382            send_stream.finish().unwrap();
2383        }
2384        let elapsed_sending: f64 = start_time.elapsed().as_secs_f64();
2385        info!("Elapsed sending: {elapsed_sending}");
2386
        // Check that all of them were delivered.
2388        let start_time = tokio::time::Instant::now();
2389        let mut num_txs_received = 0;
2390        while num_txs_received < expected_num_txs && start_time.elapsed() < Duration::from_secs(2) {
2391            if let Ok(packets) = receiver.try_recv() {
2392                num_txs_received += packets.len();
2393            } else {
2394                sleep(Duration::from_millis(100)).await;
2395            }
2396        }
2397        assert_eq!(expected_num_txs, num_txs_received);
2398
2399        // stop it
2400        exit.store(true, Ordering::Relaxed);
2401        join_handle.await.unwrap();
2402
2403        assert_eq!(
2404            stats.total_new_streams.load(Ordering::Relaxed),
2405            expected_num_txs
2406        );
2407        assert!(stats.throttled_unstaked_streams.load(Ordering::Relaxed) > 0);
2408    }
2409
2410    #[test]
2411    fn test_client_connection_tracker() {
2412        let stats = Arc::new(StreamerStats::default());
2413        let tracker_1 = ClientConnectionTracker::new(stats.clone(), 1);
2414        assert!(tracker_1.is_ok());
2415        assert!(ClientConnectionTracker::new(stats.clone(), 1).is_err());
2416        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 1);
        // Dropping the tracker should bring the open connections count back to 0.
2418        drop(tracker_1);
2419        assert_eq!(stats.open_connections.load(Ordering::Relaxed), 0);
2420    }
2421
2422    #[tokio::test(flavor = "multi_thread")]
2423    async fn test_client_connection_close_invalid_stream() {
2424        let SpawnTestServerResult {
2425            join_handle,
2426            server_address,
2427            stats,
2428            exit,
2429            ..
2430        } = setup_quic_server(None, TestServerConfig::default());
2431
2432        let client_connection = make_client_endpoint(&server_address, None).await;
2433
2434        let mut send_stream = client_connection.open_uni().await.unwrap();
2435        send_stream
2436            .write_all(&[42; PACKET_DATA_SIZE + 1])
2437            .await
2438            .unwrap();
2439        match client_connection.closed().await {
2440            ConnectionError::ApplicationClosed(ApplicationClose { error_code, reason }) => {
2441                assert_eq!(error_code, CONNECTION_CLOSE_CODE_INVALID_STREAM.into());
2442                assert_eq!(reason, CONNECTION_CLOSE_REASON_INVALID_STREAM);
2443            }
2444            _ => panic!("unexpected close"),
2445        }
2446        assert_eq!(stats.invalid_stream_size.load(Ordering::Relaxed), 1);
2447        exit.store(true, Ordering::Relaxed);
2448        join_handle.await.unwrap();
2449    }
2450}