iroh_net/dialer.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
//! A dialer to conveniently dial many nodes.
use std::{collections::HashMap, pin::Pin, task::Poll};
use anyhow::anyhow;
use futures_lite::Stream;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::error;
use crate::{Endpoint, NodeId};
/// Dials nodes and maintains a queue of pending dials.
///
/// The [`Dialer`] wraps an [`Endpoint`], connects to nodes through the endpoint, stores the
/// pending connect futures and emits finished connect results.
///
/// The [`Dialer`] also implements [`Stream`] to retrieve the dialled connections.
#[derive(Debug)]
pub struct Dialer {
/// Endpoint used to establish the outgoing connections; cloned into every dial task.
endpoint: Endpoint,
/// Spawned dial tasks. Each task resolves to the dialed node id together with the
/// connection result (or an error if the dial failed or was cancelled).
pending: JoinSet<(NodeId, anyhow::Result<quinn::Connection>)>,
/// Cancellation token for every node that currently has a dial in progress.
///
/// Used to answer `is_pending` / `pending_count` and to signal `abort_dial` to the task.
pending_dials: HashMap<NodeId, CancellationToken>,
}
impl Dialer {
    /// Creates a new dialer that connects through the given [`Endpoint`].
    ///
    /// The dialer starts out with no pending dials.
    pub fn new(endpoint: Endpoint) -> Self {
        Self {
            endpoint,
            pending: Default::default(),
            pending_dials: Default::default(),
        }
    }

    /// Starts to dial a node by [`NodeId`].
    ///
    /// If a dial to `node_id` is already pending this is a no-op. Otherwise a task is
    /// spawned that connects to the node using `alpn`; the task can be stopped early with
    /// [`Dialer::abort_dial`].
    ///
    /// Since this dials by [`NodeId`] the [`Endpoint`] must know how to contact the node by
    /// [`NodeId`] only. This relies on addressing information being provided by either the
    /// [discovery service] or manually by calling [`Endpoint::add_node_addr`].
    ///
    /// [discovery service]: crate::discovery::Discovery
    pub fn queue_dial(&mut self, node_id: NodeId, alpn: &'static [u8]) {
        // Never run two dials to the same node at once.
        if self.pending_dials.contains_key(&node_id) {
            return;
        }

        let token = CancellationToken::new();
        let task_token = token.clone();
        let endpoint = self.endpoint.clone();
        self.pending_dials.insert(node_id, token);

        self.pending.spawn(async move {
            // `biased` makes the cancellation branch get polled first, so an already
            // cancelled dial is reported as cancelled rather than racing the connect.
            let outcome = tokio::select! {
                biased;
                _ = task_token.cancelled() => Err(anyhow!("Cancelled")),
                connected = endpoint.connect(node_id, alpn) => connected
            };
            (node_id, outcome)
        });
    }

    /// Aborts a pending dial.
    ///
    /// The node is removed from the set of pending dials and its dial task is signalled to
    /// stop; that task then finishes with a `Cancelled` error. Does nothing if no dial to
    /// `node_id` is pending.
    pub fn abort_dial(&mut self, node_id: NodeId) {
        let Some(token) = self.pending_dials.remove(&node_id) else {
            return;
        };
        token.cancel();
    }

    /// Checks if a node is currently being dialed.
    pub fn is_pending(&self, node: NodeId) -> bool {
        self.pending_dials.contains_key(&node)
    }

    /// Waits for the next dial operation to complete.
    ///
    /// Returns the dialed node id together with the result of the connection attempt.
    ///
    /// If no dials are pending the returned future never resolves. Dial tasks that fail to
    /// join are logged and skipped.
    pub async fn next_conn(&mut self) -> (NodeId, anyhow::Result<quinn::Connection>) {
        // Nothing queued: park forever instead of returning a sentinel value.
        if self.pending_dials.is_empty() {
            return std::future::pending().await;
        }

        loop {
            match self.pending.join_next().await {
                Some(Ok((node_id, res))) => {
                    self.pending_dials.remove(&node_id);
                    return (node_id, res);
                }
                Some(Err(e)) => {
                    // The task did not run to completion; log it and wait for the next one.
                    error!("next conn error: {:?}", e);
                }
                None => {
                    // Dials are tracked as pending but no task is left to produce a result.
                    error!("no more pending conns available");
                    std::future::pending::<()>().await
                }
            }
        }
    }

    /// Number of pending connections to be opened.
    pub fn pending_count(&self) -> usize {
        self.pending_dials.len()
    }

    /// Returns a reference to the endpoint used in this dialer.
    pub fn endpoint(&self) -> &Endpoint {
        &self.endpoint
    }
}
/// Yields the results of finished dials as `(NodeId, Result<Connection>)` items.
///
/// The stream never terminates: while no dial result is available, including when no dials
/// are queued at all, it reports [`Poll::Pending`].
impl Stream for Dialer {
    type Item = (NodeId, anyhow::Result<quinn::Connection>);

    /// Polls for the next finished dial.
    ///
    /// Dial tasks that fail to join are logged and skipped.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        loop {
            match self.pending.poll_join_next(cx) {
                Poll::Ready(Some(Ok((node_id, result)))) => {
                    self.pending_dials.remove(&node_id);
                    return Poll::Ready(Some((node_id, result)));
                }
                Poll::Ready(Some(Err(e))) => {
                    error!("dialer error: {:?}", e);
                    // The join set returned `Ready`, so no wake-up is guaranteed for this
                    // poll. Returning `Pending` here could stall the stream even though
                    // other results are ready, so poll again instead.
                }
                // An empty set is reported as `Pending` rather than ending the stream, so
                // the dialer stays usable once new dials are queued.
                // NOTE(review): in the empty case the join set does not appear to register
                // the waker, so a task awaiting only this stream may not be woken by a
                // later `queue_dial` - confirm against callers.
                Poll::Ready(None) | Poll::Pending => return Poll::Pending,
            }
        }
    }
}