solana_streamer/nonblocking/
testing_utilities.rs

1//! Contains utility functions to create server and client for test purposes.
2use {
3    super::quic::{
4        spawn_server_multi, SpawnNonBlockingServerResult, ALPN_TPU_PROTOCOL_ID,
5        DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE, DEFAULT_MAX_STREAMS_PER_MS,
6        DEFAULT_WAIT_FOR_CHUNK_TIMEOUT,
7    },
8    crate::{
9        quic::{StreamerStats, MAX_STAKED_CONNECTIONS, MAX_UNSTAKED_CONNECTIONS},
10        streamer::StakedNodes,
11        tls_certificates::new_dummy_x509_certificate,
12    },
13    crossbeam_channel::unbounded,
14    quinn::{
15        crypto::rustls::QuicClientConfig, ClientConfig, Connection, EndpointConfig, IdleTimeout,
16        TokioRuntime, TransportConfig,
17    },
18    solana_perf::packet::PacketBatch,
19    solana_sdk::{
20        net::DEFAULT_TPU_COALESCE,
21        quic::{QUIC_KEEP_ALIVE, QUIC_MAX_TIMEOUT, QUIC_SEND_FAIRNESS},
22        signer::keypair::Keypair,
23    },
24    std::{
25        net::{SocketAddr, UdpSocket},
26        sync::{atomic::AtomicBool, Arc, RwLock},
27    },
28    tokio::task::JoinHandle,
29};
30
31#[derive(Debug)]
32pub struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);
33
34impl SkipServerVerification {
35    pub fn new() -> Arc<Self> {
36        Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider())))
37    }
38}
39
40impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
41    fn verify_tls12_signature(
42        &self,
43        message: &[u8],
44        cert: &rustls::pki_types::CertificateDer<'_>,
45        dss: &rustls::DigitallySignedStruct,
46    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
47        rustls::crypto::verify_tls12_signature(
48            message,
49            cert,
50            dss,
51            &self.0.signature_verification_algorithms,
52        )
53    }
54
55    fn verify_tls13_signature(
56        &self,
57        message: &[u8],
58        cert: &rustls::pki_types::CertificateDer<'_>,
59        dss: &rustls::DigitallySignedStruct,
60    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
61        rustls::crypto::verify_tls13_signature(
62            message,
63            cert,
64            dss,
65            &self.0.signature_verification_algorithms,
66        )
67    }
68
69    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
70        self.0.signature_verification_algorithms.supported_schemes()
71    }
72
73    fn verify_server_cert(
74        &self,
75        _end_entity: &rustls::pki_types::CertificateDer<'_>,
76        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
77        _server_name: &rustls::pki_types::ServerName<'_>,
78        _ocsp_response: &[u8],
79        _now: rustls::pki_types::UnixTime,
80    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
81        Ok(rustls::client::danger::ServerCertVerified::assertion())
82    }
83}
84
85pub fn get_client_config(keypair: &Keypair) -> ClientConfig {
86    let (cert, key) = new_dummy_x509_certificate(keypair);
87
88    let mut crypto = rustls::ClientConfig::builder()
89        .dangerous()
90        .with_custom_certificate_verifier(SkipServerVerification::new())
91        .with_client_auth_cert(vec![cert], key)
92        .expect("Failed to use client certificate");
93
94    crypto.enable_early_data = true;
95    crypto.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()];
96
97    let mut config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto).unwrap()));
98
99    let mut transport_config = TransportConfig::default();
100    let timeout = IdleTimeout::try_from(QUIC_MAX_TIMEOUT).unwrap();
101    transport_config.max_idle_timeout(Some(timeout));
102    transport_config.keep_alive_interval(Some(QUIC_KEEP_ALIVE));
103    transport_config.send_fairness(QUIC_SEND_FAIRNESS);
104    config.transport_config(Arc::new(transport_config));
105
106    config
107}
108
109#[derive(Debug, Clone)]
110pub struct TestServerConfig {
111    pub max_connections_per_peer: usize,
112    pub max_staked_connections: usize,
113    pub max_unstaked_connections: usize,
114    pub max_streams_per_ms: u64,
115    pub max_connections_per_ipaddr_per_minute: u64,
116}
117
118impl Default for TestServerConfig {
119    fn default() -> Self {
120        Self {
121            max_connections_per_peer: 1,
122            max_staked_connections: MAX_STAKED_CONNECTIONS,
123            max_unstaked_connections: MAX_UNSTAKED_CONNECTIONS,
124            max_streams_per_ms: DEFAULT_MAX_STREAMS_PER_MS,
125            max_connections_per_ipaddr_per_minute: DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE,
126        }
127    }
128}
129
130pub struct SpawnTestServerResult {
131    pub join_handle: JoinHandle<()>,
132    pub exit: Arc<AtomicBool>,
133    pub receiver: crossbeam_channel::Receiver<PacketBatch>,
134    pub server_address: SocketAddr,
135    pub stats: Arc<StreamerStats>,
136}
137
138pub fn setup_quic_server(
139    option_staked_nodes: Option<StakedNodes>,
140    config: TestServerConfig,
141) -> SpawnTestServerResult {
142    let sockets = {
143        #[cfg(not(target_os = "windows"))]
144        {
145            use std::{
146                os::fd::{FromRawFd, IntoRawFd},
147                str::FromStr as _,
148            };
149            (0..10)
150                .map(|_| {
151                    let sock = socket2::Socket::new(
152                        socket2::Domain::IPV4,
153                        socket2::Type::DGRAM,
154                        Some(socket2::Protocol::UDP),
155                    )
156                    .unwrap();
157                    sock.set_reuse_port(true).unwrap();
158                    sock.bind(&SocketAddr::from_str("127.0.0.1:0").unwrap().into())
159                        .unwrap();
160                    unsafe { UdpSocket::from_raw_fd(sock.into_raw_fd()) }
161                })
162                .collect::<Vec<_>>()
163        }
164        #[cfg(target_os = "windows")]
165        {
166            vec![UdpSocket::bind("127.0.0.1:0").unwrap()]
167        }
168    };
169    setup_quic_server_with_sockets(sockets, option_staked_nodes, config)
170}
171
172pub fn setup_quic_server_with_sockets(
173    sockets: Vec<UdpSocket>,
174    option_staked_nodes: Option<StakedNodes>,
175    TestServerConfig {
176        max_connections_per_peer,
177        max_staked_connections,
178        max_unstaked_connections,
179        max_streams_per_ms,
180        max_connections_per_ipaddr_per_minute,
181    }: TestServerConfig,
182) -> SpawnTestServerResult {
183    let exit = Arc::new(AtomicBool::new(false));
184    let (sender, receiver) = unbounded();
185    let keypair = Keypair::new();
186    let server_address = sockets[0].local_addr().unwrap();
187    let staked_nodes = Arc::new(RwLock::new(option_staked_nodes.unwrap_or_default()));
188    let SpawnNonBlockingServerResult {
189        endpoints: _,
190        stats,
191        thread: handle,
192        max_concurrent_connections: _,
193    } = spawn_server_multi(
194        "quic_streamer_test",
195        sockets,
196        &keypair,
197        sender,
198        exit.clone(),
199        max_connections_per_peer,
200        staked_nodes,
201        max_staked_connections,
202        max_unstaked_connections,
203        max_streams_per_ms,
204        max_connections_per_ipaddr_per_minute,
205        DEFAULT_WAIT_FOR_CHUNK_TIMEOUT,
206        DEFAULT_TPU_COALESCE,
207    )
208    .unwrap();
209    SpawnTestServerResult {
210        join_handle: handle,
211        exit,
212        receiver,
213        server_address,
214        stats,
215    }
216}
217
218pub async fn make_client_endpoint(
219    addr: &SocketAddr,
220    client_keypair: Option<&Keypair>,
221) -> Connection {
222    let client_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
223    let mut endpoint = quinn::Endpoint::new(
224        EndpointConfig::default(),
225        None,
226        client_socket,
227        Arc::new(TokioRuntime),
228    )
229    .unwrap();
230    let default_keypair = Keypair::new();
231    endpoint.set_default_client_config(get_client_config(
232        client_keypair.unwrap_or(&default_keypair),
233    ));
234    endpoint
235        .connect(*addr, "localhost")
236        .expect("Endpoint configuration should be correct")
237        .await
238        .expect("Test server should be already listening on 'localhost'")
239}