1use {
5 crate::{
6 client_error::ClientErrorKind, connection_cache::ConnectionCacheStats,
7 nonblocking::tpu_connection::TpuConnection, tpu_connection::ClientStats,
8 },
9 async_mutex::Mutex,
10 async_trait::async_trait,
11 futures::future::join_all,
12 itertools::Itertools,
13 log::*,
14 quinn::{
15 ClientConfig, ConnectError, ConnectionError, Endpoint, EndpointConfig, IdleTimeout,
16 NewConnection, VarInt, WriteError,
17 },
18 safecoin_measure::measure::Measure,
19 safecoin_net_utils::VALIDATOR_PORT_RANGE,
20 solana_sdk::{
21 quic::{
22 QUIC_CONNECTION_HANDSHAKE_TIMEOUT_MS, QUIC_KEEP_ALIVE_MS, QUIC_MAX_TIMEOUT_MS,
23 QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
24 },
25 signature::Keypair,
26 transport::Result as TransportResult,
27 },
28 solana_streamer::{
29 nonblocking::quic::ALPN_TPU_PROTOCOL_ID,
30 tls_certificates::new_self_signed_tls_certificate_chain,
31 },
32 std::{
33 net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
34 sync::{atomic::Ordering, Arc},
35 thread,
36 time::Duration,
37 },
38 thiserror::Error,
39 tokio::{sync::RwLock, time::timeout},
40};
41
42struct SkipServerVerification;
43
44impl SkipServerVerification {
45 pub fn new() -> Arc<Self> {
46 Arc::new(Self)
47 }
48}
49
50impl rustls::client::ServerCertVerifier for SkipServerVerification {
51 fn verify_server_cert(
52 &self,
53 _end_entity: &rustls::Certificate,
54 _intermediates: &[rustls::Certificate],
55 _server_name: &rustls::ServerName,
56 _scts: &mut dyn Iterator<Item = &[u8]>,
57 _ocsp_response: &[u8],
58 _now: std::time::SystemTime,
59 ) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
60 Ok(rustls::client::ServerCertVerified::assertion())
61 }
62}
63
/// TLS identity presented by the QUIC client: a certificate chain and its private key.
///
/// `QuicLazyInitializedEndpoint::create_endpoint` passes clones of both fields to
/// rustls' `with_single_cert`.
pub struct QuicClientCertificate {
    /// Certificate chain the client presents during the TLS handshake.
    pub certificates: Vec<rustls::Certificate>,
    /// Private key for `certificates`.
    pub key: rustls::PrivateKey,
}
68
/// QUIC client endpoint that is created on first use and then shared.
///
/// `get_endpoint` builds the underlying `Endpoint` the first time it is called and returns
/// the cached instance afterwards.
pub struct QuicLazyInitializedEndpoint {
    /// `None` until the first `get_endpoint` call creates the endpoint.
    endpoint: RwLock<Option<Arc<Endpoint>>>,
    /// TLS certificate chain and key installed in the endpoint's client config.
    client_certificate: Arc<QuicClientCertificate>,
}
74
/// Errors produced while connecting to a peer or writing to a QUIC stream.
///
/// Each variant wraps a `quinn` error transparently (`Display`/`source` forward to the
/// inner error) and has a `From` conversion, so `?` can be used on quinn results.
#[derive(Error, Debug)]
pub enum QuicError {
    /// Writing to or finishing a unidirectional send stream failed.
    #[error(transparent)]
    WriteError(#[from] WriteError),
    /// The connection failed or was lost; also used for handshake timeouts.
    #[error(transparent)]
    ConnectionError(#[from] ConnectionError),
    /// Returned by `Endpoint::connect` when a connection attempt cannot be started.
    #[error(transparent)]
    ConnectError(#[from] ConnectError),
}
84
85impl QuicLazyInitializedEndpoint {
86 pub fn new(client_certificate: Arc<QuicClientCertificate>) -> Self {
87 Self {
88 endpoint: RwLock::new(None),
89 client_certificate,
90 }
91 }
92
93 fn create_endpoint(&self) -> Endpoint {
94 let (_, client_socket) = safecoin_net_utils::bind_in_range(
95 IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
96 VALIDATOR_PORT_RANGE,
97 )
98 .expect("QuicLazyInitializedEndpoint::create_endpoint bind_in_range");
99
100 let mut crypto = rustls::ClientConfig::builder()
101 .with_safe_defaults()
102 .with_custom_certificate_verifier(SkipServerVerification::new())
103 .with_single_cert(
104 self.client_certificate.certificates.clone(),
105 self.client_certificate.key.clone(),
106 )
107 .expect("Failed to set QUIC client certificates");
108 crypto.enable_early_data = true;
109 crypto.alpn_protocols = vec![ALPN_TPU_PROTOCOL_ID.to_vec()];
110
111 let mut endpoint =
112 QuicNewConnection::create_endpoint(EndpointConfig::default(), client_socket);
113
114 let mut config = ClientConfig::new(Arc::new(crypto));
115 let transport_config = Arc::get_mut(&mut config.transport)
116 .expect("QuicLazyInitializedEndpoint::create_endpoint Arc::get_mut");
117 let timeout = IdleTimeout::from(VarInt::from_u32(QUIC_MAX_TIMEOUT_MS));
118 transport_config.max_idle_timeout(Some(timeout));
119 transport_config.keep_alive_interval(Some(Duration::from_millis(QUIC_KEEP_ALIVE_MS)));
120
121 endpoint.set_default_client_config(config);
122 endpoint
123 }
124
125 async fn get_endpoint(&self) -> Arc<Endpoint> {
126 let lock = self.endpoint.read().await;
127 let endpoint = lock.as_ref();
128
129 match endpoint {
130 Some(endpoint) => endpoint.clone(),
131 None => {
132 drop(lock);
133 let mut lock = self.endpoint.write().await;
134 let endpoint = lock.as_ref();
135
136 match endpoint {
137 Some(endpoint) => endpoint.clone(),
138 None => {
139 let connection = Arc::new(self.create_endpoint());
140 *lock = Some(connection.clone());
141 connection
142 }
143 }
144 }
145 }
146 }
147}
148
149impl Default for QuicLazyInitializedEndpoint {
150 fn default() -> Self {
151 let (certs, priv_key) = new_self_signed_tls_certificate_chain(
152 &Keypair::new(),
153 IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
154 )
155 .expect("Failed to create QUIC client certificate");
156 Self::new(Arc::new(QuicClientCertificate {
157 certificates: certs,
158 key: priv_key,
159 }))
160 }
161}
162
163#[derive(Clone)]
166struct QuicNewConnection {
167 endpoint: Arc<Endpoint>,
168 connection: Arc<NewConnection>,
169}
170
171impl QuicNewConnection {
172 async fn make_connection(
174 endpoint: Arc<QuicLazyInitializedEndpoint>,
175 addr: SocketAddr,
176 stats: &ClientStats,
177 ) -> Result<Self, QuicError> {
178 let mut make_connection_measure = Measure::start("make_connection_measure");
179 let endpoint = endpoint.get_endpoint().await;
180
181 let connecting = endpoint.connect(addr, "connect")?;
182 stats.total_connections.fetch_add(1, Ordering::Relaxed);
183 if let Ok(connecting_result) = timeout(
184 Duration::from_millis(QUIC_CONNECTION_HANDSHAKE_TIMEOUT_MS),
185 connecting,
186 )
187 .await
188 {
189 if connecting_result.is_err() {
190 stats.connection_errors.fetch_add(1, Ordering::Relaxed);
191 }
192 make_connection_measure.stop();
193 stats
194 .make_connection_ms
195 .fetch_add(make_connection_measure.as_ms(), Ordering::Relaxed);
196
197 let connection = connecting_result?;
198
199 Ok(Self {
200 endpoint,
201 connection: Arc::new(connection),
202 })
203 } else {
204 Err(ConnectionError::TimedOut.into())
205 }
206 }
207
208 fn create_endpoint(config: EndpointConfig, client_socket: UdpSocket) -> Endpoint {
209 quinn::Endpoint::new(config, None, client_socket)
210 .expect("QuicNewConnection::create_endpoint quinn::Endpoint::new")
211 .0
212 }
213
214 async fn make_connection_0rtt(
217 &mut self,
218 addr: SocketAddr,
219 stats: &ClientStats,
220 ) -> Result<Arc<NewConnection>, QuicError> {
221 let connecting = self.endpoint.connect(addr, "connect")?;
222 stats.total_connections.fetch_add(1, Ordering::Relaxed);
223 let connection = match connecting.into_0rtt() {
224 Ok((connection, zero_rtt)) => {
225 if let Ok(zero_rtt) = timeout(
226 Duration::from_millis(QUIC_CONNECTION_HANDSHAKE_TIMEOUT_MS),
227 zero_rtt,
228 )
229 .await
230 {
231 if zero_rtt {
232 stats.zero_rtt_accepts.fetch_add(1, Ordering::Relaxed);
233 } else {
234 stats.zero_rtt_rejects.fetch_add(1, Ordering::Relaxed);
235 }
236 connection
237 } else {
238 return Err(ConnectionError::TimedOut.into());
239 }
240 }
241 Err(connecting) => {
242 stats.connection_errors.fetch_add(1, Ordering::Relaxed);
243
244 if let Ok(connecting_result) = timeout(
245 Duration::from_millis(QUIC_CONNECTION_HANDSHAKE_TIMEOUT_MS),
246 connecting,
247 )
248 .await
249 {
250 connecting_result?
251 } else {
252 return Err(ConnectionError::TimedOut.into());
253 }
254 }
255 };
256 self.connection = Arc::new(connection);
257 Ok(self.connection.clone())
258 }
259}
260
261pub struct QuicClient {
262 endpoint: Arc<QuicLazyInitializedEndpoint>,
263 connection: Arc<Mutex<Option<QuicNewConnection>>>,
264 addr: SocketAddr,
265 stats: Arc<ClientStats>,
266 chunk_size: usize,
267}
268
269impl QuicClient {
270 pub fn new(
271 endpoint: Arc<QuicLazyInitializedEndpoint>,
272 addr: SocketAddr,
273 chunk_size: usize,
274 ) -> Self {
275 Self {
276 endpoint,
277 connection: Arc::new(Mutex::new(None)),
278 addr,
279 stats: Arc::new(ClientStats::default()),
280 chunk_size,
281 }
282 }
283
284 async fn _send_buffer_using_conn(
285 data: &[u8],
286 connection: &NewConnection,
287 ) -> Result<(), QuicError> {
288 let mut send_stream = connection.connection.open_uni().await?;
289
290 send_stream.write_all(data).await?;
291 send_stream.finish().await?;
292 Ok(())
293 }
294
295 async fn _send_buffer(
298 &self,
299 data: &[u8],
300 stats: &ClientStats,
301 connection_stats: Arc<ConnectionCacheStats>,
302 ) -> Result<Arc<NewConnection>, QuicError> {
303 let mut connection_try_count = 0;
304 let mut last_connection_id = 0;
305 let mut last_error = None;
306
307 while connection_try_count < 2 {
308 let connection = {
309 let mut conn_guard = self.connection.lock().await;
310
311 let maybe_conn = conn_guard.as_mut();
312 match maybe_conn {
313 Some(conn) => {
314 if conn.connection.connection.stable_id() == last_connection_id {
315 let conn = conn.make_connection_0rtt(self.addr, stats).await;
317 match conn {
318 Ok(conn) => {
319 info!(
320 "Made 0rtt connection to {} with id {} try_count {}, last_connection_id: {}, last_error: {:?}",
321 self.addr,
322 conn.connection.stable_id(),
323 connection_try_count,
324 last_connection_id,
325 last_error,
326 );
327 connection_try_count += 1;
328 conn
329 }
330 Err(err) => {
331 info!(
332 "Cannot make 0rtt connection to {}, error {:}",
333 self.addr, err
334 );
335 return Err(err);
336 }
337 }
338 } else {
339 stats.connection_reuse.fetch_add(1, Ordering::Relaxed);
340 conn.connection.clone()
341 }
342 }
343 None => {
344 let conn = QuicNewConnection::make_connection(
345 self.endpoint.clone(),
346 self.addr,
347 stats,
348 )
349 .await;
350 match conn {
351 Ok(conn) => {
352 *conn_guard = Some(conn.clone());
353 info!(
354 "Made connection to {} id {} try_count {}",
355 self.addr,
356 conn.connection.connection.stable_id(),
357 connection_try_count
358 );
359 connection_try_count += 1;
360 conn.connection.clone()
361 }
362 Err(err) => {
363 info!("Cannot make connection to {}, error {:}", self.addr, err);
364 return Err(err);
365 }
366 }
367 }
368 }
369 };
370
371 let new_stats = connection.connection.stats();
372
373 connection_stats
374 .total_client_stats
375 .congestion_events
376 .update_stat(
377 &self.stats.congestion_events,
378 new_stats.path.congestion_events,
379 );
380
381 connection_stats
382 .total_client_stats
383 .tx_streams_blocked_uni
384 .update_stat(
385 &self.stats.tx_streams_blocked_uni,
386 new_stats.frame_tx.streams_blocked_uni,
387 );
388
389 connection_stats
390 .total_client_stats
391 .tx_data_blocked
392 .update_stat(&self.stats.tx_data_blocked, new_stats.frame_tx.data_blocked);
393
394 connection_stats
395 .total_client_stats
396 .tx_acks
397 .update_stat(&self.stats.tx_acks, new_stats.frame_tx.acks);
398
399 last_connection_id = connection.connection.stable_id();
400 match Self::_send_buffer_using_conn(data, &connection).await {
401 Ok(()) => {
402 return Ok(connection);
403 }
404 Err(err) => match err {
405 QuicError::ConnectionError(_) => {
406 last_error = Some(err);
407 }
408 _ => {
409 info!(
410 "Error sending to {} with id {}, error {:?} thread: {:?}",
411 self.addr,
412 connection.connection.stable_id(),
413 err,
414 thread::current().id(),
415 );
416 return Err(err);
417 }
418 },
419 }
420 }
421
422 info!(
424 "Ran into an error sending transactions {:?}, exhausted retries to {}",
425 last_error, self.addr
426 );
427 Err(last_error.expect("QuicClient::_send_buffer last_error.expect"))
430 }
431
432 pub async fn send_buffer<T>(
433 &self,
434 data: T,
435 stats: &ClientStats,
436 connection_stats: Arc<ConnectionCacheStats>,
437 ) -> Result<(), ClientErrorKind>
438 where
439 T: AsRef<[u8]>,
440 {
441 self._send_buffer(data.as_ref(), stats, connection_stats)
442 .await?;
443 Ok(())
444 }
445
446 pub async fn send_batch<T>(
447 &self,
448 buffers: &[T],
449 stats: &ClientStats,
450 connection_stats: Arc<ConnectionCacheStats>,
451 ) -> Result<(), ClientErrorKind>
452 where
453 T: AsRef<[u8]>,
454 {
455 if buffers.is_empty() {
467 return Ok(());
468 }
469 let connection = self
470 ._send_buffer(buffers[0].as_ref(), stats, connection_stats)
471 .await?;
472
473 let connection_ref: &NewConnection = &connection;
476
477 let chunks = buffers[1..buffers.len()].iter().chunks(self.chunk_size);
478
479 let futures: Vec<_> = chunks
480 .into_iter()
481 .map(|buffs| {
482 join_all(
483 buffs
484 .into_iter()
485 .map(|buf| Self::_send_buffer_using_conn(buf.as_ref(), connection_ref)),
486 )
487 })
488 .collect();
489
490 for f in futures {
491 f.await.into_iter().try_for_each(|res| res)?;
492 }
493 Ok(())
494 }
495
496 pub fn tpu_addr(&self) -> &SocketAddr {
497 &self.addr
498 }
499
500 pub fn stats(&self) -> Arc<ClientStats> {
501 self.stats.clone()
502 }
503}
504
505pub struct QuicTpuConnection {
506 client: Arc<QuicClient>,
507 connection_stats: Arc<ConnectionCacheStats>,
508}
509
510impl QuicTpuConnection {
511 pub fn base_stats(&self) -> Arc<ClientStats> {
512 self.client.stats()
513 }
514
515 pub fn connection_stats(&self) -> Arc<ConnectionCacheStats> {
516 self.connection_stats.clone()
517 }
518
519 pub fn new(
520 endpoint: Arc<QuicLazyInitializedEndpoint>,
521 addr: SocketAddr,
522 connection_stats: Arc<ConnectionCacheStats>,
523 ) -> Self {
524 let client = Arc::new(QuicClient::new(
525 endpoint,
526 addr,
527 QUIC_MAX_UNSTAKED_CONCURRENT_STREAMS,
528 ));
529 Self::new_with_client(client, connection_stats)
530 }
531
532 pub fn new_with_client(
533 client: Arc<QuicClient>,
534 connection_stats: Arc<ConnectionCacheStats>,
535 ) -> Self {
536 Self {
537 client,
538 connection_stats,
539 }
540 }
541}
542
543#[async_trait]
544impl TpuConnection for QuicTpuConnection {
545 fn tpu_addr(&self) -> &SocketAddr {
546 self.client.tpu_addr()
547 }
548
549 async fn send_wire_transaction_batch<T>(&self, buffers: &[T]) -> TransportResult<()>
550 where
551 T: AsRef<[u8]> + Send + Sync,
552 {
553 let stats = ClientStats::default();
554 let len = buffers.len();
555 let res = self
556 .client
557 .send_batch(buffers, &stats, self.connection_stats.clone())
558 .await;
559 self.connection_stats
560 .add_client_stats(&stats, len, res.is_ok());
561 res?;
562 Ok(())
563 }
564
565 async fn send_wire_transaction<T>(&self, wire_transaction: T) -> TransportResult<()>
566 where
567 T: AsRef<[u8]> + Send + Sync,
568 {
569 let stats = Arc::new(ClientStats::default());
570 let send_buffer =
571 self.client
572 .send_buffer(wire_transaction, &stats, self.connection_stats.clone());
573 if let Err(e) = send_buffer.await {
574 warn!(
575 "Failed to send transaction async to {}, error: {:?} ",
576 self.tpu_addr(),
577 e
578 );
579 datapoint_warn!("send-wire-async", ("failure", 1, i64),);
580 self.connection_stats.add_client_stats(&stats, 1, false);
581 } else {
582 self.connection_stats.add_client_stats(&stats, 1, true);
583 }
584 Ok(())
585 }
586}