solana_tpu_client/nonblocking/tpu_client.rs

pub use crate::tpu_client::Result;
use {
    crate::tpu_client::{RecentLeaderSlots, TpuClientConfig, MAX_FANOUT_SLOTS},
    bincode::serialize,
    futures_util::{future::join_all, stream::StreamExt},
    log::*,
    solana_clock::{Slot, DEFAULT_MS_PER_SLOT, NUM_CONSECUTIVE_LEADER_SLOTS},
    solana_commitment_config::CommitmentConfig,
    solana_connection_cache::{
        connection_cache::{
            ConnectionCache, ConnectionManager, ConnectionPool, NewConnectionConfig, Protocol,
            DEFAULT_CONNECTION_POOL_SIZE,
        },
        nonblocking::client_connection::ClientConnection,
    },
    solana_epoch_info::EpochInfo,
    solana_pubkey::Pubkey,
    solana_pubsub_client::nonblocking::pubsub_client::{PubsubClient, PubsubClientError},
    solana_quic_definitions::QUIC_PORT_OFFSET,
    solana_rpc_client::nonblocking::rpc_client::RpcClient,
    solana_rpc_client_api::{
        client_error::{Error as ClientError, ErrorKind, Result as ClientResult},
        request::RpcError,
        response::{RpcContactInfo, SlotUpdate},
    },
    solana_signer::SignerError,
    solana_transaction::Transaction,
    solana_transaction_error::{TransportError, TransportResult},
    std::{
        collections::{HashMap, HashSet},
        net::SocketAddr,
        str::FromStr,
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc, RwLock,
        },
    },
    thiserror::Error,
    tokio::{
        task::JoinHandle,
        time::{sleep, timeout, Duration, Instant},
    },
};
#[cfg(feature = "spinner")]
use {
    crate::tpu_client::{SEND_TRANSACTION_INTERVAL, TRANSACTION_RESEND_INTERVAL},
    futures_util::FutureExt,
    indicatif::ProgressBar,
    solana_message::Message,
    solana_rpc_client::spinner::{self, SendTransactionProgress},
    solana_rpc_client_api::request::MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS,
    solana_signer::signers::Signers,
    solana_transaction_error::TransactionError,
    std::{future::Future, iter},
};

#[derive(Error, Debug)]
pub enum TpuSenderError {
    #[error("Pubsub error: {0:?}")]
    PubsubError(#[from] PubsubClientError),
    #[error("RPC error: {0:?}")]
    RpcError(#[from] ClientError),
    #[error("IO error: {0:?}")]
    IoError(#[from] std::io::Error),
    #[error("Signer error: {0:?}")]
    SignerError(#[from] SignerError),
    #[error("Custom error: {0}")]
    Custom(String),
}

struct LeaderTpuCacheUpdateInfo {
    pub(super) maybe_cluster_nodes: Option<ClientResult<Vec<RpcContactInfo>>>,
    pub(super) maybe_epoch_info: Option<ClientResult<EpochInfo>>,
    pub(super) maybe_slot_leaders: Option<ClientResult<Vec<Pubkey>>>,
}
impl LeaderTpuCacheUpdateInfo {
    pub fn has_some(&self) -> bool {
        self.maybe_cluster_nodes.is_some()
            || self.maybe_epoch_info.is_some()
            || self.maybe_slot_leaders.is_some()
    }
}

struct LeaderTpuCache {
    protocol: Protocol,
    first_slot: Slot,
    leaders: Vec<Pubkey>,
    leader_tpu_map: HashMap<Pubkey, SocketAddr>,
    slots_in_epoch: Slot,
    last_epoch_info_slot: Slot,
}

impl LeaderTpuCache {
    pub fn new(
        first_slot: Slot,
        slots_in_epoch: Slot,
        leaders: Vec<Pubkey>,
        cluster_nodes: Vec<RpcContactInfo>,
        protocol: Protocol,
    ) -> Self {
        let leader_tpu_map = Self::extract_cluster_tpu_sockets(protocol, cluster_nodes);
        Self {
            protocol,
            first_slot,
            leaders,
            leader_tpu_map,
            slots_in_epoch,
            last_epoch_info_slot: first_slot,
        }
    }

    // Last slot that has a cached leader pubkey
    pub fn last_slot(&self) -> Slot {
        self.first_slot + self.leaders.len().saturating_sub(1) as u64
    }

    pub fn slot_info(&self) -> (Slot, Slot, Slot) {
        (
            self.last_slot(),
            self.last_epoch_info_slot,
            self.slots_in_epoch,
        )
    }

    // Get the TPU sockets for the current leader and upcoming *unique* leaders according to fanout size.
    fn get_unique_leader_sockets(
        &self,
        estimated_current_slot: Slot,
        fanout_slots: u64,
    ) -> Vec<SocketAddr> {
        let all_leader_sockets = self.get_leader_sockets(estimated_current_slot, fanout_slots);

        let mut unique_sockets = Vec::new();
        let mut seen = HashSet::new();

        for socket in all_leader_sockets {
            if seen.insert(socket) {
                unique_sockets.push(socket);
            }
        }

        unique_sockets
    }

    // Get the TPU sockets for the current leader and upcoming leaders according to fanout size.
    fn get_leader_sockets(
        &self,
        estimated_current_slot: Slot,
        fanout_slots: u64,
    ) -> Vec<SocketAddr> {
        let mut leader_sockets = Vec::new();
        // `first_slot` might have been advanced since caller last read the `estimated_current_slot`
        // value. Take the greater of the two values to ensure we are reading from the latest
        // leader schedule.
        let current_slot = std::cmp::max(estimated_current_slot, self.first_slot);
        for leader_slot in (current_slot..current_slot + fanout_slots)
            .step_by(NUM_CONSECUTIVE_LEADER_SLOTS as usize)
        {
            if let Some(leader) = self.get_slot_leader(leader_slot) {
                if let Some(tpu_socket) = self.leader_tpu_map.get(leader) {
                    leader_sockets.push(*tpu_socket);
                } else {
                    // The leader is probably delinquent
                    trace!("TPU not available for leader {}", leader);
                }
            } else {
                // Overran the local leader schedule cache
                warn!(
                    "Leader not known for slot {}; cache holds slots [{},{}]",
                    leader_slot,
                    self.first_slot,
                    self.last_slot()
                );
            }
        }
        leader_sockets
    }
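
    // Worked example of the stepping above (a sketch, assuming the upstream value
    // NUM_CONSECUTIVE_LEADER_SLOTS == 4): with `estimated_current_slot == 100` and
    // `fanout_slots == 8`, only slots 100 and 104 are looked up, because consecutive slots
    // are assigned to leaders in groups of NUM_CONSECUTIVE_LEADER_SLOTS; at most two sockets
    // are returned, and `get_unique_leader_sockets` would additionally drop the second one
    // if both slots map to the same node.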

    pub fn get_slot_leader(&self, slot: Slot) -> Option<&Pubkey> {
        if slot >= self.first_slot {
            let index = slot - self.first_slot;
            self.leaders.get(index as usize)
        } else {
            None
        }
    }

    fn extract_cluster_tpu_sockets(
        protocol: Protocol,
        cluster_contact_info: Vec<RpcContactInfo>,
    ) -> HashMap<Pubkey, SocketAddr> {
        cluster_contact_info
            .into_iter()
            .filter_map(|contact_info| {
                let pubkey = Pubkey::from_str(&contact_info.pubkey).ok()?;
                let socket = match protocol {
                    Protocol::QUIC => contact_info.tpu_quic.or_else(|| {
                        let mut socket = contact_info.tpu?;
                        let port = socket.port().checked_add(QUIC_PORT_OFFSET)?;
                        socket.set_port(port);
                        Some(socket)
                    }),
                    Protocol::UDP => contact_info.tpu,
                }?;
                Some((pubkey, socket))
            })
            .collect()
    }
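
    // For example (a sketch, assuming the upstream value QUIC_PORT_OFFSET == 6): a node that
    // advertises a UDP TPU socket of 192.0.2.1:8003 and no `tpu_quic` entry maps to
    // 192.0.2.1:8009 under `Protocol::QUIC`. Nodes without a parsable pubkey, without a TPU
    // socket, or whose derived port would overflow a u16 are silently skipped.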

    pub fn fanout(slots_in_epoch: Slot) -> Slot {
        (2 * MAX_FANOUT_SLOTS).min(slots_in_epoch)
    }
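
    // E.g. assuming the crate's MAX_FANOUT_SLOTS == 100, each leader-schedule fetch requests
    // up to 200 slots worth of leaders, capped at the number of slots in the epoch.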

    pub fn update_all(
        &mut self,
        estimated_current_slot: Slot,
        cache_update_info: LeaderTpuCacheUpdateInfo,
    ) -> (bool, bool) {
        let mut has_error = false;
        let mut cluster_refreshed = false;
        if let Some(cluster_nodes) = cache_update_info.maybe_cluster_nodes {
            match cluster_nodes {
                Ok(cluster_nodes) => {
                    self.leader_tpu_map =
                        Self::extract_cluster_tpu_sockets(self.protocol, cluster_nodes);
                    cluster_refreshed = true;
                }
                Err(err) => {
                    warn!("Failed to fetch cluster tpu sockets: {}", err);
                    has_error = true;
                }
            }
        }

        if let Some(Ok(epoch_info)) = cache_update_info.maybe_epoch_info {
            self.slots_in_epoch = epoch_info.slots_in_epoch;
            self.last_epoch_info_slot = estimated_current_slot;
        }

        if let Some(slot_leaders) = cache_update_info.maybe_slot_leaders {
            match slot_leaders {
                Ok(slot_leaders) => {
                    self.first_slot = estimated_current_slot;
                    self.leaders = slot_leaders;
                }
                Err(err) => {
                    warn!(
                        "Failed to fetch slot leaders (current estimated slot: {}): {}",
                        estimated_current_slot, err
                    );
                    has_error = true;
                }
            }
        }
        (has_error, cluster_refreshed)
    }
}

/// Client which sends transactions directly to the current leader's TPU port, over UDP or
/// QUIC depending on the connection cache's protocol. The client uses RPC to determine the
/// current leader and to fetch node contact info.
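///
/// # Example
///
/// A minimal usage sketch (not compiled here, hence `ignore`). The QUIC pool, manager, and
/// config type names below are assumptions about the companion QUIC client crate and are not
/// defined in this module; substitute whatever connection-cache types you actually use.
///
/// ```ignore
/// use std::sync::Arc;
/// use solana_rpc_client::nonblocking::rpc_client::RpcClient;
/// use solana_tpu_client::tpu_client::TpuClientConfig;
///
/// # async fn example(transaction: &solana_transaction::Transaction) -> Result<(), Box<dyn std::error::Error>> {
/// let rpc_client = Arc::new(RpcClient::new("http://localhost:8899".to_string()));
/// // Assumed QUIC connection manager; any `ConnectionManager` implementation works here.
/// let connection_manager = QuicConnectionManager::new_with_connection_config(QuicConfig::new()?);
/// let tpu_client = TpuClient::<QuicPool, QuicConnectionManager, QuicConfig>::new(
///     "tpu-client-example",
///     rpc_client,
///     "ws://localhost:8900",
///     TpuClientConfig::default(),
///     connection_manager,
/// )
/// .await?;
/// // Fan the transaction out to the current and upcoming leader TPUs.
/// let _delivered: bool = tpu_client.send_transaction(transaction).await;
/// # Ok(())
/// # }
/// ```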
pub struct TpuClient<
    P, // ConnectionPool
    M, // ConnectionManager
    C, // NewConnectionConfig
> {
    fanout_slots: u64,
    leader_tpu_service: LeaderTpuService,
    exit: Arc<AtomicBool>,
    rpc_client: Arc<RpcClient>,
    connection_cache: Arc<ConnectionCache<P, M, C>>,
}

/// Helper function which generates futures to all be awaited together for maximum
/// throughput
#[cfg(feature = "spinner")]
fn send_wire_transaction_futures<'a, P, M, C>(
    progress_bar: &'a ProgressBar,
    progress: &'a SendTransactionProgress,
    index: usize,
    num_transactions: usize,
    wire_transaction: Vec<u8>,
    leaders: Vec<SocketAddr>,
    connection_cache: &'a ConnectionCache<P, M, C>,
) -> Vec<impl Future<Output = TransportResult<()>> + 'a>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    const SEND_TIMEOUT_INTERVAL: Duration = Duration::from_secs(5);
    let sleep_duration = SEND_TRANSACTION_INTERVAL.saturating_mul(index as u32);
    let send_timeout = SEND_TIMEOUT_INTERVAL.saturating_add(sleep_duration);
    leaders
        .into_iter()
        .map(|addr| {
            timeout_future(
                send_timeout,
                sleep_and_send_wire_transaction_to_addr(
                    sleep_duration,
                    connection_cache,
                    addr,
                    wire_transaction.clone(),
                ),
            )
            .boxed_local() // required to make types work simply
        })
        .chain(iter::once(
            timeout_future(
                send_timeout,
                sleep_and_set_message(
                    sleep_duration,
                    progress_bar,
                    progress,
                    index,
                    num_transactions,
                ),
            )
            .boxed_local(), // required to make types work simply
        ))
        .collect::<Vec<_>>()
}

// Wrap an existing future with a timeout.
//
// Useful for end-users who don't need a persistent connection to each validator,
// and want to abort more quickly.
#[cfg(feature = "spinner")]
async fn timeout_future<Fut: Future<Output = TransportResult<()>>>(
    timeout_duration: Duration,
    future: Fut,
) -> TransportResult<()> {
    timeout(timeout_duration, future)
        .await
        .unwrap_or_else(|_| Err(TransportError::Custom("Timed out".to_string())))
}

#[cfg(feature = "spinner")]
async fn sleep_and_set_message(
    sleep_duration: Duration,
    progress_bar: &ProgressBar,
    progress: &SendTransactionProgress,
    index: usize,
    num_transactions: usize,
) -> TransportResult<()> {
    sleep(sleep_duration).await;
    progress.set_message_for_confirmed_transactions(
        progress_bar,
        &format!("Sending {}/{} transactions", index + 1, num_transactions,),
    );
    Ok(())
}

#[cfg(feature = "spinner")]
async fn sleep_and_send_wire_transaction_to_addr<P, M, C>(
    sleep_duration: Duration,
    connection_cache: &ConnectionCache<P, M, C>,
    addr: SocketAddr,
    wire_transaction: Vec<u8>,
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    sleep(sleep_duration).await;
    send_wire_transaction_to_addr(connection_cache, &addr, wire_transaction).await
}

async fn send_wire_transaction_to_addr<P, M, C>(
    connection_cache: &ConnectionCache<P, M, C>,
    addr: &SocketAddr,
    wire_transaction: Vec<u8>,
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    let conn = connection_cache.get_nonblocking_connection(addr);
    conn.send_data(&wire_transaction).await
}

async fn send_wire_transaction_batch_to_addr<P, M, C>(
    connection_cache: &ConnectionCache<P, M, C>,
    addr: &SocketAddr,
    wire_transactions: &[Vec<u8>],
) -> TransportResult<()>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    let conn = connection_cache.get_nonblocking_connection(addr);
    conn.send_data_batch(wire_transactions).await
}

impl<P, M, C> TpuClient<P, M, C>
where
    P: ConnectionPool<NewConnectionConfig = C>,
    M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
    C: NewConnectionConfig,
{
    /// Serialize and send transaction to the current and upcoming leader TPUs according to fanout
    /// size
    pub async fn send_transaction(&self, transaction: &Transaction) -> bool {
        let wire_transaction = serialize(transaction).expect("serialization should succeed");
        self.send_wire_transaction(wire_transaction).await
    }

    /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size
    pub async fn send_wire_transaction(&self, wire_transaction: Vec<u8>) -> bool {
        self.try_send_wire_transaction(wire_transaction)
            .await
            .is_ok()
    }

    /// Serialize and send transaction to the current and upcoming leader TPUs according to fanout
    /// size
    /// Returns the first error encountered if all sends fail
    pub async fn try_send_transaction(&self, transaction: &Transaction) -> TransportResult<()> {
        let wire_transaction = serialize(transaction).expect("serialization should succeed");
        self.try_send_wire_transaction(wire_transaction).await
    }

    /// Send a wire transaction to the current and upcoming leader TPUs according to fanout size
    /// Returns the first error encountered if all sends fail
    pub async fn try_send_wire_transaction(
        &self,
        wire_transaction: Vec<u8>,
    ) -> TransportResult<()> {
        let leaders = self
            .leader_tpu_service
            .unique_leader_tpu_sockets(self.fanout_slots);
        let futures = leaders
            .iter()
            .map(|addr| {
                send_wire_transaction_to_addr(
                    &self.connection_cache,
                    addr,
                    wire_transaction.clone(),
                )
            })
            .collect::<Vec<_>>();
        let results: Vec<TransportResult<()>> = join_all(futures).await;

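        // Treat the fanout as successful if any leader accepted the data; otherwise
        // surface the first error that was observed.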
        let mut last_error: Option<TransportError> = None;
        let mut some_success = false;
        for result in results {
            if let Err(e) = result {
                if last_error.is_none() {
                    last_error = Some(e);
                }
            } else {
                some_success = true;
            }
        }
        if !some_success {
            Err(if let Some(err) = last_error {
                err
            } else {
                std::io::Error::new(std::io::ErrorKind::Other, "No sends attempted").into()
            })
        } else {
            Ok(())
        }
    }

    /// Send a batch of wire transactions to the current and upcoming leader TPUs according to
    /// fanout size
    /// Returns the first error encountered if all sends fail
    pub async fn try_send_wire_transaction_batch(
        &self,
        wire_transactions: Vec<Vec<u8>>,
    ) -> TransportResult<()> {
        let leaders = self
            .leader_tpu_service
            .unique_leader_tpu_sockets(self.fanout_slots);
        let futures = leaders
            .iter()
            .map(|addr| {
                send_wire_transaction_batch_to_addr(
                    &self.connection_cache,
                    addr,
                    &wire_transactions,
                )
            })
            .collect::<Vec<_>>();
        let results: Vec<TransportResult<()>> = join_all(futures).await;

        let mut last_error: Option<TransportError> = None;
        let mut some_success = false;
        for result in results {
            if let Err(e) = result {
                if last_error.is_none() {
                    last_error = Some(e);
                }
            } else {
                some_success = true;
            }
        }
        if !some_success {
            Err(if let Some(err) = last_error {
                err
            } else {
                std::io::Error::new(std::io::ErrorKind::Other, "No sends attempted").into()
            })
        } else {
            Ok(())
        }
    }

    /// Create a new client that disconnects when dropped
    pub async fn new(
        name: &'static str,
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        config: TpuClientConfig,
        connection_manager: M,
    ) -> Result<Self> {
        let connection_cache = Arc::new(
            ConnectionCache::new(name, connection_manager, DEFAULT_CONNECTION_POOL_SIZE).unwrap(),
        ); // TODO: Handle error properly, as the ConnectionCache ctor is now fallible.
        Self::new_with_connection_cache(rpc_client, websocket_url, config, connection_cache).await
    }

    /// Create a new client that disconnects when dropped
    pub async fn new_with_connection_cache(
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        config: TpuClientConfig,
        connection_cache: Arc<ConnectionCache<P, M, C>>,
    ) -> Result<Self> {
        let exit = Arc::new(AtomicBool::new(false));
        let leader_tpu_service =
            LeaderTpuService::new(rpc_client.clone(), websocket_url, M::PROTOCOL, exit.clone())
                .await?;

        Ok(Self {
            fanout_slots: config.fanout_slots.clamp(1, MAX_FANOUT_SLOTS),
            leader_tpu_service,
            exit,
            rpc_client,
            connection_cache,
        })
    }

    #[cfg(feature = "spinner")]
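    /// Sign, send, and confirm `messages`, updating a progress spinner along the way.
    ///
    /// Pending transactions are periodically re-sent to the upcoming leader TPUs (falling
    /// back to RPC submission for a transaction when every TPU send for it fails) and their
    /// signature statuses are polled until they are confirmed or the blockhash expires; on
    /// expiry the remaining transactions are re-signed with a fresh blockhash, for up to
    /// five rounds before giving up. Returns the per-message transaction errors on success.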
    pub async fn send_and_confirm_messages_with_spinner<T: Signers + ?Sized>(
        &self,
        messages: &[Message],
        signers: &T,
    ) -> Result<Vec<Option<TransactionError>>> {
        let mut progress = SendTransactionProgress::default();
        let progress_bar = spinner::new_progress_bar();
        progress_bar.set_message("Setting up...");

        let mut transactions = messages
            .iter()
            .enumerate()
            .map(|(i, message)| (i, Transaction::new_unsigned(message.clone())))
            .collect::<Vec<_>>();
        progress.total_transactions = transactions.len();
        let mut transaction_errors = vec![None; transactions.len()];
        progress.block_height = self.rpc_client.get_block_height().await?;
        for expired_blockhash_retries in (0..5).rev() {
            let (blockhash, last_valid_block_height) = self
                .rpc_client
                .get_latest_blockhash_with_commitment(self.rpc_client.commitment())
                .await?;
            progress.last_valid_block_height = last_valid_block_height;

            let mut pending_transactions = HashMap::new();
            for (i, mut transaction) in transactions {
                transaction.try_sign(signers, blockhash)?;
                pending_transactions.insert(transaction.signatures[0], (i, transaction));
            }

            let mut last_resend = Instant::now() - TRANSACTION_RESEND_INTERVAL;
            while progress.block_height <= progress.last_valid_block_height {
                let num_transactions = pending_transactions.len();

                // Periodically re-send all pending transactions
                if Instant::now().duration_since(last_resend) > TRANSACTION_RESEND_INTERVAL {
                    // Prepare futures for all transactions
                    let mut futures = vec![];
                    for (index, (_i, transaction)) in pending_transactions.values().enumerate() {
                        let wire_transaction = serialize(transaction).unwrap();
                        let leaders = self
                            .leader_tpu_service
                            .unique_leader_tpu_sockets(self.fanout_slots);
                        futures.extend(send_wire_transaction_futures(
                            &progress_bar,
                            &progress,
                            index,
                            num_transactions,
                            wire_transaction,
                            leaders,
                            &self.connection_cache,
                        ));
                    }

                    // Start the process of sending them all
                    let results = join_all(futures).await;

                    progress.set_message_for_confirmed_transactions(
                        &progress_bar,
                        "Checking sent transactions",
                    );
                    for (index, (tx_results, (_i, transaction))) in results
                        .chunks(self.fanout_slots as usize)
                        .zip(pending_transactions.values())
                        .enumerate()
                    {
                        // Only report an error if every future in the chunk errored
                        if tx_results.iter().all(|r| r.is_err()) {
                            progress.set_message_for_confirmed_transactions(
                                &progress_bar,
                                &format!(
                                    "Resending failed transaction {} of {}",
                                    index + 1,
                                    num_transactions,
                                ),
                            );
                            let _result = self.rpc_client.send_transaction(transaction).await.ok();
                        }
                    }
                    last_resend = Instant::now();
                }

                // Wait for the next block before checking for transaction statuses
                let mut block_height_refreshes = 10;
                progress.set_message_for_confirmed_transactions(
                    &progress_bar,
                    &format!("Waiting for next block, {num_transactions} transactions pending..."),
                );
                let mut new_block_height = progress.block_height;
                while progress.block_height == new_block_height && block_height_refreshes > 0 {
                    sleep(Duration::from_millis(500)).await;
                    new_block_height = self.rpc_client.get_block_height().await?;
                    block_height_refreshes -= 1;
                }
                progress.block_height = new_block_height;

                // Collect statuses for the transactions, drop those that are confirmed
                let pending_signatures = pending_transactions.keys().cloned().collect::<Vec<_>>();
                for pending_signatures_chunk in
                    pending_signatures.chunks(MAX_GET_SIGNATURE_STATUSES_QUERY_ITEMS)
                {
                    if let Ok(result) = self
                        .rpc_client
                        .get_signature_statuses(pending_signatures_chunk)
                        .await
                    {
                        let statuses = result.value;
                        for (signature, status) in
                            pending_signatures_chunk.iter().zip(statuses.into_iter())
                        {
                            if let Some(status) = status {
                                if status.satisfies_commitment(self.rpc_client.commitment()) {
                                    if let Some((i, _)) = pending_transactions.remove(signature) {
                                        progress.confirmed_transactions += 1;
                                        if status.err.is_some() {
                                            progress_bar
                                                .println(format!("Failed transaction: {status:?}"));
                                        }
                                        transaction_errors[i] = status.err;
                                    }
                                }
                            }
                        }
                    }
                    progress.set_message_for_confirmed_transactions(
                        &progress_bar,
                        "Checking transaction status...",
                    );
                }

                if pending_transactions.is_empty() {
                    return Ok(transaction_errors);
                }
            }

            transactions = pending_transactions.into_values().collect();
            progress_bar.println(format!(
                "Blockhash expired. {expired_blockhash_retries} retries remaining"
            ));
        }
        Err(TpuSenderError::Custom("Max retries exceeded".into()))
    }

    pub fn rpc_client(&self) -> &RpcClient {
        &self.rpc_client
    }

    pub async fn shutdown(&mut self) {
        self.exit.store(true, Ordering::Relaxed);
        self.leader_tpu_service.join().await;
    }

    pub fn get_connection_cache(&self) -> &Arc<ConnectionCache<P, M, C>>
    where
        P: ConnectionPool<NewConnectionConfig = C>,
        M: ConnectionManager<ConnectionPool = P, NewConnectionConfig = C>,
        C: NewConnectionConfig,
    {
        &self.connection_cache
    }

    pub fn get_leader_tpu_service(&self) -> &LeaderTpuService {
        &self.leader_tpu_service
    }

    pub fn get_fanout_slots(&self) -> u64 {
        self.fanout_slots
    }
}

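// Dropping the client only signals the background `LeaderTpuService` task to exit; call
// `shutdown()` to also wait for that task to finish.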
impl<P, M, C> Drop for TpuClient<P, M, C> {
    fn drop(&mut self) {
        self.exit.store(true, Ordering::Relaxed);
    }
}

/// Service that tracks upcoming leaders and maintains an up-to-date mapping
/// of leader id to TPU socket address.
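///
/// # Example
///
/// A minimal sketch of standalone use (not compiled here, hence `ignore`); it relies only on
/// items visible in this module plus the nonblocking `RpcClient`, and the import path for
/// `LeaderTpuService` is an assumption based on this module's location.
///
/// ```ignore
/// use std::sync::{
///     atomic::{AtomicBool, Ordering},
///     Arc,
/// };
/// use solana_connection_cache::connection_cache::Protocol;
/// use solana_rpc_client::nonblocking::rpc_client::RpcClient;
/// use solana_tpu_client::nonblocking::tpu_client::LeaderTpuService;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let rpc_client = Arc::new(RpcClient::new("http://localhost:8899".to_string()));
/// let exit = Arc::new(AtomicBool::new(false));
/// let mut service =
///     LeaderTpuService::new(rpc_client, "ws://localhost:8900", Protocol::QUIC, exit.clone())
///         .await?;
/// // TPU addresses for the estimated current leader and the next few leaders.
/// let _sockets = service.leader_tpu_sockets(4);
/// exit.store(true, Ordering::Relaxed);
/// service.join().await;
/// # Ok(())
/// # }
/// ```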
pub struct LeaderTpuService {
    recent_slots: RecentLeaderSlots,
    leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
    t_leader_tpu_service: Option<JoinHandle<Result<()>>>,
}

impl LeaderTpuService {
    pub async fn new(
        rpc_client: Arc<RpcClient>,
        websocket_url: &str,
        protocol: Protocol,
        exit: Arc<AtomicBool>,
    ) -> Result<Self> {
        let start_slot = rpc_client
            .get_slot_with_commitment(CommitmentConfig::processed())
            .await?;

        let recent_slots = RecentLeaderSlots::new(start_slot);
        let slots_in_epoch = rpc_client.get_epoch_info().await?.slots_in_epoch;

        // When a cluster is starting, we observe an invalid slot range failure that goes away after a
        // retry. It seems as if the leader schedule is not available, but it should be. The logic
        // below retries the RPC call in case of an invalid slot range error.
        let tpu_leader_service_creation_timeout = Duration::from_secs(20);
        let retry_interval = Duration::from_secs(1);
        let leaders = timeout(tpu_leader_service_creation_timeout, async {
            loop {
                // TODO: The root cause appears to lie within the `rpc_client.get_slot_leaders()`.
                // It might be worth debugging further and trying to understand why the RPC
                // call fails. There may be a bug in the `get_slot_leaders()` logic or in the
                // RPC implementation
                match rpc_client
                    .get_slot_leaders(start_slot, LeaderTpuCache::fanout(slots_in_epoch))
                    .await
                {
                    Ok(leaders) => return Ok(leaders),
                    Err(client_error) => {
                        if is_invalid_slot_range_error(&client_error) {
                            sleep(retry_interval).await;
                            continue;
                        } else {
                            return Err(client_error);
                        }
                    }
                }
            }
        })
        .await
        .map_err(|_| {
            TpuSenderError::Custom(format!(
                "Failed to get slot leaders connecting to: {}, timeout: {:?}. Invalid slot range",
                websocket_url, tpu_leader_service_creation_timeout
            ))
        })??;

        let cluster_nodes = rpc_client.get_cluster_nodes().await?;
        let leader_tpu_cache = Arc::new(RwLock::new(LeaderTpuCache::new(
            start_slot,
            slots_in_epoch,
            leaders,
            cluster_nodes,
            protocol,
        )));

        let pubsub_client = if !websocket_url.is_empty() {
            Some(PubsubClient::new(websocket_url).await?)
        } else {
            None
        };

        let t_leader_tpu_service = Some({
            let recent_slots = recent_slots.clone();
            let leader_tpu_cache = leader_tpu_cache.clone();
            tokio::spawn(Self::run(
                rpc_client,
                recent_slots,
                leader_tpu_cache,
                pubsub_client,
                exit,
            ))
        });

        Ok(LeaderTpuService {
            recent_slots,
            leader_tpu_cache,
            t_leader_tpu_service,
        })
    }

    pub async fn join(&mut self) {
        if let Some(t_handle) = self.t_leader_tpu_service.take() {
            t_handle.await.unwrap().unwrap();
        }
    }

    pub fn estimated_current_slot(&self) -> Slot {
        self.recent_slots.estimated_current_slot()
    }

    pub fn unique_leader_tpu_sockets(&self, fanout_slots: u64) -> Vec<SocketAddr> {
        let current_slot = self.recent_slots.estimated_current_slot();
        self.leader_tpu_cache
            .read()
            .unwrap()
            .get_unique_leader_sockets(current_slot, fanout_slots)
    }

    pub fn leader_tpu_sockets(&self, fanout_slots: u64) -> Vec<SocketAddr> {
        let current_slot = self.recent_slots.estimated_current_slot();
        self.leader_tpu_cache
            .read()
            .unwrap()
            .get_leader_sockets(current_slot, fanout_slots)
    }

    async fn run(
        rpc_client: Arc<RpcClient>,
        recent_slots: RecentLeaderSlots,
        leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
        pubsub_client: Option<PubsubClient>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        tokio::try_join!(
            Self::run_slot_watcher(recent_slots.clone(), pubsub_client, exit.clone()),
            Self::run_cache_refresher(rpc_client, recent_slots, leader_tpu_cache, exit),
        )?;

        Ok(())
    }

    async fn run_cache_refresher(
        rpc_client: Arc<RpcClient>,
        recent_slots: RecentLeaderSlots,
        leader_tpu_cache: Arc<RwLock<LeaderTpuCache>>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        let mut last_cluster_refresh = Instant::now();
        let mut sleep_ms = DEFAULT_MS_PER_SLOT;

        while !exit.load(Ordering::Relaxed) {
            // Sleep a slot before checking if the leader cache needs to be refreshed again
            sleep(Duration::from_millis(sleep_ms)).await;
            sleep_ms = DEFAULT_MS_PER_SLOT;

            let cache_update_info = maybe_fetch_cache_info(
                &leader_tpu_cache,
                last_cluster_refresh,
                &rpc_client,
                &recent_slots,
            )
            .await;

            if cache_update_info.has_some() {
                let mut leader_tpu_cache = leader_tpu_cache.write().unwrap();
                let (has_error, cluster_refreshed) = leader_tpu_cache
                    .update_all(recent_slots.estimated_current_slot(), cache_update_info);
                if has_error {
                    sleep_ms = 100;
                }
                if cluster_refreshed {
                    last_cluster_refresh = Instant::now();
                }
            }
        }

        Ok(())
    }

    async fn run_slot_watcher(
        recent_slots: RecentLeaderSlots,
        pubsub_client: Option<PubsubClient>,
        exit: Arc<AtomicBool>,
    ) -> Result<()> {
        let Some(pubsub_client) = pubsub_client else {
            return Ok(());
        };

        let (mut notifications, unsubscribe) = pubsub_client.slot_updates_subscribe().await?;
        // Time out slot update notification polling at 10ms.
        //
        // Rationale is two-fold:
        // 1. Notifications are an unbounded stream -- polling them will block indefinitely if not
        //    interrupted, and the exit condition will never be checked. 10ms ensures negligible
        //    CPU overhead while keeping notification checking timely.
        // 2. The timeout must be strictly less than the slot time (DEFAULT_MS_PER_SLOT: 400) to
        //    avoid timeout never being reached. For example, if notifications are received every
        //    400ms and the timeout is >= 400ms, notifications may theoretically always be available
        //    before the timeout is reached, resulting in the exit condition never being checked.
        const SLOT_UPDATE_TIMEOUT: Duration = Duration::from_millis(10);

        while !exit.load(Ordering::Relaxed) {
            while let Ok(Some(update)) = timeout(SLOT_UPDATE_TIMEOUT, notifications.next()).await {
                let current_slot = match update {
                    // This update indicates that a full slot was received by the connected
                    // node so we can stop sending transactions to the leader for that slot
                    SlotUpdate::Completed { slot, .. } => slot.saturating_add(1),
                    // This update indicates that we have just received the first shred from
                    // the leader for this slot and they are probably still accepting transactions.
                    SlotUpdate::FirstShredReceived { slot, .. } => slot,
                    _ => continue,
                };
                recent_slots.record_slot(current_slot);
            }
        }

        // `notifications` requires a valid reference to `pubsub_client`, so `notifications` must be
        // dropped before moving `pubsub_client` via `shutdown()`.
        drop(notifications);
        unsubscribe().await;
        pubsub_client.shutdown().await?;

        Ok(())
    }
}

async fn maybe_fetch_cache_info(
    leader_tpu_cache: &Arc<RwLock<LeaderTpuCache>>,
    last_cluster_refresh: Instant,
    rpc_client: &RpcClient,
    recent_slots: &RecentLeaderSlots,
) -> LeaderTpuCacheUpdateInfo {
    // Refresh cluster TPU ports every 5min in case validators restart with new port configuration
    // or new validators come online
    let maybe_cluster_nodes = if last_cluster_refresh.elapsed() > Duration::from_secs(5 * 60) {
        Some(rpc_client.get_cluster_nodes().await)
    } else {
        None
    };

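    // Top up the cached leader schedule before it runs out: once the estimated current slot
    // comes within MAX_FANOUT_SLOTS of the last cached slot, a fresh window of slot leaders
    // (sized by `LeaderTpuCache::fanout`) is requested.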
    let estimated_current_slot = recent_slots.estimated_current_slot();
    let (last_slot, last_epoch_info_slot, slots_in_epoch) = {
        let leader_tpu_cache = leader_tpu_cache.read().unwrap();
        leader_tpu_cache.slot_info()
    };
    let maybe_epoch_info =
        if estimated_current_slot >= last_epoch_info_slot.saturating_sub(slots_in_epoch) {
            Some(rpc_client.get_epoch_info().await)
        } else {
            None
        };

    let maybe_slot_leaders = if estimated_current_slot >= last_slot.saturating_sub(MAX_FANOUT_SLOTS)
    {
        Some(
            rpc_client
                .get_slot_leaders(
                    estimated_current_slot,
                    LeaderTpuCache::fanout(slots_in_epoch),
                )
                .await,
        )
    } else {
        None
    };
    LeaderTpuCacheUpdateInfo {
        maybe_cluster_nodes,
        maybe_epoch_info,
        maybe_slot_leaders,
    }
}

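// -32602 is the JSON-RPC "invalid params" error code; combined with this message it means
// the node does not yet have a leader schedule covering the requested slot range, which
// `LeaderTpuService::new` treats as retryable.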
fn is_invalid_slot_range_error(client_error: &ClientError) -> bool {
    if let ErrorKind::RpcError(RpcError::RpcResponseError { code, message, .. }) =
        &client_error.kind
    {
        return *code == -32602
            && message.contains("Invalid slot range: leader schedule for epoch");
    }
    false
}