solana_quic_client/nonblocking/
quic_client.rs

//! Simple nonblocking client that connects to a given UDP port using the QUIC protocol
//! and provides an interface for sending data; sends are subject to the
//! server's flow control.
4use {
5    async_lock::Mutex,
6    async_trait::async_trait,
7    futures::future::TryFutureExt,
8    log::*,
9    quinn::{
10        crypto::rustls::QuicClientConfig, ClientConfig, ClosedStream, ConnectError, Connection,
11        ConnectionError, Endpoint, EndpointConfig, IdleTimeout, TokioRuntime, TransportConfig,
12        WriteError,
13    },
14    solana_connection_cache::{
15        client_connection::ClientStats, connection_cache_stats::ConnectionCacheStats,
16        nonblocking::client_connection::ClientConnection,
17    },
18    solana_measure::measure::Measure,
19    solana_net_utils::VALIDATOR_PORT_RANGE,
20    solana_rpc_client_api::client_error::ErrorKind as ClientErrorKind,
21    solana_sdk::{
22        quic::{
23            QUIC_CONNECTION_HANDSHAKE_TIMEOUT, QUIC_KEEP_ALIVE, QUIC_MAX_TIMEOUT,
24            QUIC_SEND_FAIRNESS,
25        },
26        signature::Keypair,
27        transport::Result as TransportResult,
28    },
29    solana_streamer::{
30        nonblocking::quic::ALPN_TPU_PROTOCOL_ID, tls_certificates::new_dummy_x509_certificate,
31    },
32    std::{
33        net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
34        sync::{atomic::Ordering, Arc},
35        thread,
36    },
37    thiserror::Error,
38    tokio::{sync::OnceCell, time::timeout},
39};
40
41#[derive(Debug)]
42pub struct SkipServerVerification(Arc<rustls::crypto::CryptoProvider>);
43
44impl SkipServerVerification {
45    pub fn new() -> Arc<Self> {
46        Arc::new(Self(Arc::new(rustls::crypto::ring::default_provider())))
47    }
48}
49
50impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
51    fn verify_tls12_signature(
52        &self,
53        message: &[u8],
54        cert: &rustls::pki_types::CertificateDer<'_>,
55        dss: &rustls::DigitallySignedStruct,
56    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
57        rustls::crypto::verify_tls12_signature(
58            message,
59            cert,
60            dss,
61            &self.0.signature_verification_algorithms,
62        )
63    }
64
65    fn verify_tls13_signature(
66        &self,
67        message: &[u8],
68        cert: &rustls::pki_types::CertificateDer<'_>,
69        dss: &rustls::DigitallySignedStruct,
70    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
71        rustls::crypto::verify_tls13_signature(
72            message,
73            cert,
74            dss,
75            &self.0.signature_verification_algorithms,
76        )
77    }
78
79    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
80        self.0.signature_verification_algorithms.supported_schemes()
81    }
82
83    fn verify_server_cert(
84        &self,
85        _end_entity: &rustls::pki_types::CertificateDer<'_>,
86        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
87        _server_name: &rustls::pki_types::ServerName<'_>,
88        _ocsp_response: &[u8],
89        _now: rustls::pki_types::UnixTime,
90    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
91        Ok(rustls::client::danger::ServerCertVerified::assertion())
92    }
93}
94
95pub struct QuicClientCertificate {
96    pub certificate: rustls::pki_types::CertificateDer<'static>,
97    pub key: rustls::pki_types::PrivateKeyDer<'static>,
98}
99
100/// A lazy-initialized Quic Endpoint
101pub struct QuicLazyInitializedEndpoint {
102    endpoint: OnceCell<Arc<Endpoint>>,
103    client_certificate: Arc<QuicClientCertificate>,
104    client_endpoint: Option<Endpoint>,
105}
106
107#[derive(Error, Debug)]
108pub enum QuicError {
109    #[error(transparent)]
110    WriteError(#[from] WriteError),
111    #[error(transparent)]
112    ConnectionError(#[from] ConnectionError),
113    #[error(transparent)]
114    ConnectError(#[from] ConnectError),
115    #[error(transparent)]
116    ClosedStream(#[from] ClosedStream),
117}
118
119impl From<QuicError> for ClientErrorKind {
120    fn from(quic_error: QuicError) -> Self {
121        Self::Custom(format!("{quic_error:?}"))
122    }
123}
124
125impl QuicLazyInitializedEndpoint {
126    pub fn new(
127        client_certificate: Arc<QuicClientCertificate>,
128        client_endpoint: Option<Endpoint>,
129    ) -> Self {
130        Self {
131            endpoint: OnceCell::<Arc<Endpoint>>::new(),
132            client_certificate,
133            client_endpoint,
134        }
135    }
136
137    fn create_endpoint(&self) -> Endpoint {
138        let mut endpoint = if let Some(endpoint) = &self.client_endpoint {
139            endpoint.clone()
140        } else {
141            let client_socket = solana_net_utils::bind_in_range(
142                IpAddr::V4(Ipv4Addr::UNSPECIFIED),
143                VALIDATOR_PORT_RANGE,
144            )
145            .expect("QuicLazyInitializedEndpoint::create_endpoint bind_in_range")
146            .1;
147
148            QuicNewConnection::create_endpoint(EndpointConfig::default(), client_socket)
149        };
150
151        let mut crypto = rustls::ClientConfig::builder()
152            .dangerous()
153            .with_custom_certificate_verifier(SkipServerVerification::new())
154            .with_client_auth_cert(
155                vec![self.client_certificate.certificate.clone()],
156                self.client_certificate.key.clone_key(),
157            )
158            .expect("Failed to set QUIC client certificates");
159        crypto.enable_early_data = true;
160        crypto.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()];
161
162        let mut config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(crypto).unwrap()));
163        let mut transport_config = TransportConfig::default();
164
165        let timeout = IdleTimeout::try_from(QUIC_MAX_TIMEOUT).unwrap();
166        transport_config.max_idle_timeout(Some(timeout));
167        transport_config.keep_alive_interval(Some(QUIC_KEEP_ALIVE));
168        transport_config.send_fairness(QUIC_SEND_FAIRNESS);
169        config.transport_config(Arc::new(transport_config));
170
171        endpoint.set_default_client_config(config);
172
173        endpoint
174    }
175
176    async fn get_endpoint(&self) -> Arc<Endpoint> {
177        self.endpoint
178            .get_or_init(|| async { Arc::new(self.create_endpoint()) })
179            .await
180            .clone()
181    }
182}
183
184impl Default for QuicLazyInitializedEndpoint {
185    fn default() -> Self {
186        let (cert, priv_key) = new_dummy_x509_certificate(&Keypair::new());
187        Self::new(
188            Arc::new(QuicClientCertificate {
189                certificate: cert,
190                key: priv_key,
191            }),
192            None,
193        )
194    }
195}
196
197/// A wrapper over NewConnection with additional capability to create the endpoint as part
198/// of creating a new connection.
199#[derive(Clone)]
200struct QuicNewConnection {
201    endpoint: Arc<Endpoint>,
202    connection: Arc<Connection>,
203}
204
205impl QuicNewConnection {
206    /// Create a QuicNewConnection given the remote address 'addr'.
207    async fn make_connection(
208        endpoint: Arc<QuicLazyInitializedEndpoint>,
209        addr: SocketAddr,
210        stats: &ClientStats,
211    ) -> Result<Self, QuicError> {
212        let mut make_connection_measure = Measure::start("make_connection_measure");
213        let endpoint = endpoint.get_endpoint().await;
214
215        let connecting = endpoint.connect(addr, "connect")?;
216        stats.total_connections.fetch_add(1, Ordering::Relaxed);
217        if let Ok(connecting_result) = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await
218        {
219            if connecting_result.is_err() {
220                stats.connection_errors.fetch_add(1, Ordering::Relaxed);
221            }
222            make_connection_measure.stop();
223            stats
224                .make_connection_ms
225                .fetch_add(make_connection_measure.as_ms(), Ordering::Relaxed);
226
227            let connection = connecting_result?;
228
229            Ok(Self {
230                endpoint,
231                connection: Arc::new(connection),
232            })
233        } else {
234            Err(ConnectionError::TimedOut.into())
235        }
236    }
237
238    fn create_endpoint(config: EndpointConfig, client_socket: UdpSocket) -> Endpoint {
239        quinn::Endpoint::new(config, None, client_socket, Arc::new(TokioRuntime))
240            .expect("QuicNewConnection::create_endpoint quinn::Endpoint::new")
241    }
242
243    // Attempts to make a faster connection by taking advantage of pre-existing key material.
244    // Only works if connection to this endpoint was previously established.
245    async fn make_connection_0rtt(
246        &mut self,
247        addr: SocketAddr,
248        stats: &ClientStats,
249    ) -> Result<Arc<Connection>, QuicError> {
250        let connecting = self.endpoint.connect(addr, "connect")?;
251        stats.total_connections.fetch_add(1, Ordering::Relaxed);
252        let connection = match connecting.into_0rtt() {
253            Ok((connection, zero_rtt)) => {
254                if let Ok(zero_rtt) = timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, zero_rtt).await {
255                    if zero_rtt {
256                        stats.zero_rtt_accepts.fetch_add(1, Ordering::Relaxed);
257                    } else {
258                        stats.zero_rtt_rejects.fetch_add(1, Ordering::Relaxed);
259                    }
260                    connection
261                } else {
262                    return Err(ConnectionError::TimedOut.into());
263                }
264            }
265            Err(connecting) => {
266                stats.connection_errors.fetch_add(1, Ordering::Relaxed);
267
268                if let Ok(connecting_result) =
269                    timeout(QUIC_CONNECTION_HANDSHAKE_TIMEOUT, connecting).await
270                {
271                    connecting_result?
272                } else {
273                    return Err(ConnectionError::TimedOut.into());
274                }
275            }
276        };
277        self.connection = Arc::new(connection);
278        Ok(self.connection.clone())
279    }
280}
281
282pub struct QuicClient {
283    endpoint: Arc<QuicLazyInitializedEndpoint>,
284    connection: Arc<Mutex<Option<QuicNewConnection>>>,
285    addr: SocketAddr,
286    stats: Arc<ClientStats>,
287}
288
289impl QuicClient {
290    pub fn new(endpoint: Arc<QuicLazyInitializedEndpoint>, addr: SocketAddr) -> Self {
291        Self {
292            endpoint,
293            connection: Arc::new(Mutex::new(None)),
294            addr,
295            stats: Arc::new(ClientStats::default()),
296        }
297    }
298
299    async fn _send_buffer_using_conn(
300        data: &[u8],
301        connection: &Connection,
302    ) -> Result<(), QuicError> {
303        let mut send_stream = connection.open_uni().await?;
304        send_stream.write_all(data).await?;
305        Ok(())
306    }
307
308    // Attempts to send data, connecting/reconnecting as necessary
309    // On success, returns the connection used to successfully send the data
310    async fn _send_buffer(
311        &self,
312        data: &[u8],
313        stats: &ClientStats,
314        connection_stats: Arc<ConnectionCacheStats>,
315    ) -> Result<Arc<Connection>, QuicError> {
316        let mut measure_send_packet = Measure::start("send_packet_us");
317        let mut measure_prepare_connection = Measure::start("prepare_connection");
318        let mut connection_try_count = 0;
319        let mut last_connection_id = 0;
320        let mut last_error = None;
321        while connection_try_count < 2 {
322            let connection = {
323                let mut conn_guard = self.connection.lock().await;
324
325                let maybe_conn = conn_guard.as_mut();
326                match maybe_conn {
327                    Some(conn) => {
328                        if conn.connection.stable_id() == last_connection_id {
329                            // this is the problematic connection we had used before, create a new one
330                            let conn = conn.make_connection_0rtt(self.addr, stats).await;
331                            match conn {
332                                Ok(conn) => {
333                                    info!(
334                                        "Made 0rtt connection to {} with id {} try_count {}, last_connection_id: {}, last_error: {:?}",
335                                        self.addr,
336                                        conn.stable_id(),
337                                        connection_try_count,
338                                        last_connection_id,
339                                        last_error,
340                                    );
341                                    connection_try_count += 1;
342                                    conn
343                                }
344                                Err(err) => {
345                                    info!(
346                                        "Cannot make 0rtt connection to {}, error {:}",
347                                        self.addr, err
348                                    );
349                                    return Err(err);
350                                }
351                            }
352                        } else {
353                            stats.connection_reuse.fetch_add(1, Ordering::Relaxed);
354                            conn.connection.clone()
355                        }
356                    }
357                    None => {
358                        let conn = QuicNewConnection::make_connection(
359                            self.endpoint.clone(),
360                            self.addr,
361                            stats,
362                        )
363                        .await;
364                        match conn {
365                            Ok(conn) => {
366                                *conn_guard = Some(conn.clone());
367                                info!(
368                                    "Made connection to {} id {} try_count {}, from connection cache warming?: {}",
369                                    self.addr,
370                                    conn.connection.stable_id(),
371                                    connection_try_count,
372                                    data.is_empty(),
373                                );
374                                connection_try_count += 1;
375                                conn.connection.clone()
376                            }
377                            Err(err) => {
378                                info!("Cannot make connection to {}, error {:}, from connection cache warming?: {}",
379                                    self.addr, err, data.is_empty());
380                                return Err(err);
381                            }
382                        }
383                    }
384                }
385            };
386
387            let new_stats = connection.stats();
388
389            connection_stats
390                .total_client_stats
391                .congestion_events
392                .update_stat(
393                    &self.stats.congestion_events,
394                    new_stats.path.congestion_events,
395                );
396
397            connection_stats
398                .total_client_stats
399                .streams_blocked_uni
400                .update_stat(
401                    &self.stats.streams_blocked_uni,
402                    new_stats.frame_tx.streams_blocked_uni,
403                );
404
405            connection_stats
406                .total_client_stats
407                .data_blocked
408                .update_stat(&self.stats.data_blocked, new_stats.frame_tx.data_blocked);
409
410            connection_stats
411                .total_client_stats
412                .acks
413                .update_stat(&self.stats.acks, new_stats.frame_tx.acks);
414
415            if data.is_empty() {
416                // no need to send packet as it is only for warming connections
417                return Ok(connection);
418            }
419
420            last_connection_id = connection.stable_id();
421            measure_prepare_connection.stop();
422
423            match Self::_send_buffer_using_conn(data, &connection).await {
424                Ok(()) => {
425                    measure_send_packet.stop();
426                    stats.successful_packets.fetch_add(1, Ordering::Relaxed);
427                    stats
428                        .send_packets_us
429                        .fetch_add(measure_send_packet.as_us(), Ordering::Relaxed);
430                    stats
431                        .prepare_connection_us
432                        .fetch_add(measure_prepare_connection.as_us(), Ordering::Relaxed);
433                    trace!(
434                        "Succcessfully sent to {} with id {}, thread: {:?}, data len: {}, send_packet_us: {} prepare_connection_us: {}",
435                        self.addr,
436                        connection.stable_id(),
437                        thread::current().id(),
438                        data.len(),
439                        measure_send_packet.as_us(),
440                        measure_prepare_connection.as_us(),
441                    );
442
443                    return Ok(connection);
444                }
445                Err(err) => match err {
446                    QuicError::ConnectionError(_) => {
447                        last_error = Some(err);
448                    }
449                    _ => {
450                        info!(
451                            "Error sending to {} with id {}, error {:?} thread: {:?}",
452                            self.addr,
453                            connection.stable_id(),
454                            err,
455                            thread::current().id(),
456                        );
457                        return Err(err);
458                    }
459                },
460            }
461        }
462
463        // if we come here, that means we have exhausted maximum retries, return the error
464        info!(
465            "Ran into an error sending data {:?}, exhausted retries to {}",
466            last_error, self.addr
467        );
468        // If we get here but last_error is None, then we have a logic error
469        // in this function, so panic here with an expect to help debugging
470        Err(last_error.expect("QuicClient::_send_buffer last_error.expect"))
471    }
472
473    pub async fn send_buffer<T>(
474        &self,
475        data: T,
476        stats: &ClientStats,
477        connection_stats: Arc<ConnectionCacheStats>,
478    ) -> Result<(), ClientErrorKind>
479    where
480        T: AsRef<[u8]>,
481    {
482        self._send_buffer(data.as_ref(), stats, connection_stats)
483            .await
484            .map_err(Into::<ClientErrorKind>::into)?;
485        Ok(())
486    }
487
488    pub async fn send_batch<T>(
489        &self,
490        buffers: &[T],
491        stats: &ClientStats,
492        connection_stats: Arc<ConnectionCacheStats>,
493    ) -> Result<(), ClientErrorKind>
494    where
495        T: AsRef<[u8]>,
496    {
497        // Start off by "testing" the connection by sending the first buffer
498        // This will also connect to the server if not already connected
499        // and reconnect and retry if the first send attempt failed
500        // (for example due to a timed out connection), returning an error
501        // or the connection that was used to successfully send the buffer.
502        // We will use the returned connection to send the rest of the buffers in the batch
503        // to avoid touching the mutex in self, and not bother reconnecting if we fail along the way
504        // since testing even in the ideal GCE environment has found no cases
505        // where reconnecting and retrying in the middle of a batch send
506        // (i.e. we encounter a connection error in the middle of a batch send, which presumably cannot
507        // be due to a timed out connection) has succeeded
508        if buffers.is_empty() {
509            return Ok(());
510        }
511        let connection = self
512            ._send_buffer(buffers[0].as_ref(), stats, connection_stats)
513            .await
514            .map_err(Into::<ClientErrorKind>::into)?;
515
516        for data in buffers[1..buffers.len()].iter() {
517            Self::_send_buffer_using_conn(data.as_ref(), &connection).await?;
518        }
519        Ok(())
520    }
521
522    pub fn server_addr(&self) -> &SocketAddr {
523        &self.addr
524    }
525
526    pub fn stats(&self) -> Arc<ClientStats> {
527        self.stats.clone()
528    }
529}
530
531pub struct QuicClientConnection {
532    pub client: Arc<QuicClient>,
533    pub connection_stats: Arc<ConnectionCacheStats>,
534}
535
536impl QuicClientConnection {
537    pub fn base_stats(&self) -> Arc<ClientStats> {
538        self.client.stats()
539    }
540
541    pub fn connection_stats(&self) -> Arc<ConnectionCacheStats> {
542        self.connection_stats.clone()
543    }
544
545    pub fn new(
546        endpoint: Arc<QuicLazyInitializedEndpoint>,
547        addr: SocketAddr,
548        connection_stats: Arc<ConnectionCacheStats>,
549    ) -> Self {
550        let client = Arc::new(QuicClient::new(endpoint, addr));
551        Self::new_with_client(client, connection_stats)
552    }
553
554    pub fn new_with_client(
555        client: Arc<QuicClient>,
556        connection_stats: Arc<ConnectionCacheStats>,
557    ) -> Self {
558        Self {
559            client,
560            connection_stats,
561        }
562    }
563}
564
565#[async_trait]
566impl ClientConnection for QuicClientConnection {
567    fn server_addr(&self) -> &SocketAddr {
568        self.client.server_addr()
569    }
570
571    async fn send_data_batch(&self, buffers: &[Vec<u8>]) -> TransportResult<()> {
572        let stats = ClientStats::default();
573        let len = buffers.len();
574        let res = self
575            .client
576            .send_batch(buffers, &stats, self.connection_stats.clone())
577            .await;
578        self.connection_stats
579            .add_client_stats(&stats, len, res.is_ok());
580        res?;
581        Ok(())
582    }
583
584    async fn send_data(&self, data: &[u8]) -> TransportResult<()> {
585        let stats = Arc::new(ClientStats::default());
586        // When data is empty which is from cache warmer, we are not sending packets actually, do not count it in
587        let num_packets = if data.is_empty() { 0 } else { 1 };
588        self.client
589            .send_buffer(data, &stats, self.connection_stats.clone())
590            .map_ok(|v| {
591                self.connection_stats
592                    .add_client_stats(&stats, num_packets, true);
593                v
594            })
595            .map_err(|e| {
596                warn!(
597                    "Failed to send data async to {}, error: {:?} ",
598                    self.server_addr(),
599                    e
600                );
601                datapoint_warn!("send-wire-async", ("failure", 1, i64),);
602                self.connection_stats
603                    .add_client_stats(&stats, num_packets, false);
604                e.into()
605            })
606            .await
607    }
608}