iroh_net/
dialer.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
//! A dialer to conveniently dial many nodes.

use std::{collections::HashMap, pin::Pin, task::Poll};

use anyhow::anyhow;
use futures_lite::Stream;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::error;

use crate::{Endpoint, NodeId};

/// Dials nodes and maintains a queue of pending dials.
///
/// The [`Dialer`] wraps an [`Endpoint`], connects to nodes through the endpoint, stores the
/// pending connect futures and emits finished connect results.
///
/// The [`Dialer`] also implements [`Stream`] to retrieve the dialled connections.
#[derive(Debug)]
pub struct Dialer {
    endpoint: Endpoint,
    pending: JoinSet<(NodeId, anyhow::Result<quinn::Connection>)>,
    pending_dials: HashMap<NodeId, CancellationToken>,
}

impl Dialer {
    /// Create a new dialer for a [`Endpoint`]
    pub fn new(endpoint: Endpoint) -> Self {
        Self {
            endpoint,
            pending: Default::default(),
            pending_dials: Default::default(),
        }
    }

    /// Starts to dial a node by [`NodeId`].
    ///
    /// Since this dials by [`NodeId`] the [`Endpoint`] must know how to contact the node by
    /// [`NodeId`] only.  This relies on addressing information being provided by either the
    /// [discovery service] or manually by calling [`Endpoint::add_node_addr`].
    ///
    /// [discovery service]: crate::discovery::Discovery
    pub fn queue_dial(&mut self, node_id: NodeId, alpn: &'static [u8]) {
        if self.is_pending(node_id) {
            return;
        }
        let cancel = CancellationToken::new();
        self.pending_dials.insert(node_id, cancel.clone());
        let endpoint = self.endpoint.clone();
        self.pending.spawn(async move {
            let res = tokio::select! {
                biased;
                _ = cancel.cancelled() => Err(anyhow!("Cancelled")),
                res = endpoint.connect(node_id, alpn) => res
            };
            (node_id, res)
        });
    }

    /// Aborts a pending dial.
    pub fn abort_dial(&mut self, node_id: NodeId) {
        if let Some(cancel) = self.pending_dials.remove(&node_id) {
            cancel.cancel();
        }
    }

    /// Checks if a node is currently being dialed.
    pub fn is_pending(&self, node: NodeId) -> bool {
        self.pending_dials.contains_key(&node)
    }

    /// Waits for the next dial operation to complete.
    pub async fn next_conn(&mut self) -> (NodeId, anyhow::Result<quinn::Connection>) {
        match self.pending_dials.is_empty() {
            false => {
                let (node_id, res) = loop {
                    match self.pending.join_next().await {
                        Some(Ok((node_id, res))) => {
                            self.pending_dials.remove(&node_id);
                            break (node_id, res);
                        }
                        Some(Err(e)) => {
                            error!("next conn error: {:?}", e);
                        }
                        None => {
                            error!("no more pending conns available");
                            std::future::pending().await
                        }
                    }
                };

                (node_id, res)
            }
            true => std::future::pending().await,
        }
    }

    /// Number of pending connections to be opened.
    pub fn pending_count(&self) -> usize {
        self.pending_dials.len()
    }

    /// Returns a reference to the endpoint used in this dialer.
    pub fn endpoint(&self) -> &Endpoint {
        &self.endpoint
    }
}

impl Stream for Dialer {
    type Item = (NodeId, anyhow::Result<quinn::Connection>);

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.pending.poll_join_next(cx) {
            Poll::Ready(Some(Ok((node_id, result)))) => {
                self.pending_dials.remove(&node_id);
                Poll::Ready(Some((node_id, result)))
            }
            Poll::Ready(Some(Err(e))) => {
                error!("dialer error: {:?}", e);
                Poll::Pending
            }
            _ => Poll::Pending,
        }
    }
}