use std::{
collections::{hash_map::Entry, BTreeSet, HashMap},
hash::Hash,
net::{IpAddr, SocketAddr},
pin::Pin,
task::{Context, Poll},
time::Instant,
};
use futures_lite::stream::Stream;
use iroh_base::key::NodeId;
use iroh_metrics::inc;
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use stun_rs::TransactionId;
use tracing::{debug, info, instrument, trace, warn};
use self::{
best_addr::ClearReason,
node_state::{NodeState, Options, PingHandled},
};
use super::{
metrics::Metrics as MagicsockMetrics, ActorMessage, DiscoMessageSource, QuicMappedAddr,
};
use crate::{
disco::{CallMeMaybe, Pong, SendAddr},
key::PublicKey,
relay::RelayUrl,
stun, NodeAddr,
};
mod best_addr;
mod node_state;
mod path_state;
mod udp_paths;
pub use node_state::{ConnectionType, ControlMsg, DirectAddrInfo, RemoteInfo};
pub(super) use node_state::{DiscoPingPurpose, PingAction, PingRole, SendPing};
/// Number of inactive nodes the map keeps around.
///
/// `NodeMapInner::prune_inactive` removes inactive nodes beyond this count, starting with
/// those that were used least recently (never-used nodes first).
const MAX_INACTIVE_NODES: usize = 30;
/// Shared map of all known remote nodes.
///
/// All state lives in a [`NodeMapInner`] guarded by a `parking_lot::Mutex`; every method on
/// this type locks it and delegates to the inner map.
#[derive(Default, Debug)]
pub(super) struct NodeMap {
    inner: Mutex<NodeMapInner>,
}
/// The unsynchronized node map.
///
/// Node states are stored once in `by_id`, keyed by an internal index; the other maps are
/// secondary indexes from various node identifiers to that index.
#[derive(Default, Debug)]
pub(super) struct NodeMapInner {
    /// Node public key -> node index.
    by_node_key: HashMap<NodeId, usize>,
    /// Remote ip:port -> index of the node believed to be reachable there.
    by_ip_port: HashMap<IpPort, usize>,
    /// QUIC mapped address -> node index.
    by_quic_mapped_addr: HashMap<QuicMappedAddr, usize>,
    /// Node index -> node state (the primary storage).
    by_id: HashMap<usize, NodeState>,
    /// Index handed to the next inserted node (incremented with wrapping arithmetic).
    next_id: usize,
}
/// Any of the keys a node can be looked up by; resolved to a node index by
/// `NodeMapInner::get_id`.
#[derive(Debug, Clone)]
enum NodeStateKey {
    /// The internal node index itself.
    Idx(usize),
    /// The node's public key.
    NodeId(NodeId),
    /// The node's QUIC mapped address.
    QuicMappedAddr(QuicMappedAddr),
    /// A remote ip:port associated with the node.
    IpPort(IpPort),
}
/// Where a node's addressing information came from.
///
/// Recorded in the node's `Options` when it is inserted and passed along whenever address
/// info is merged via `update_from_node_addr`. Displayed in kebab-case; the `Discovery` and
/// `NamedApp` variants request via `strum` that they display as their `name` field.
#[derive(Serialize, Deserialize, strum::Display, Debug, Clone, Eq, PartialEq, Hash)]
#[strum(serialize_all = "kebab-case")]
pub enum Source {
    /// Loaded through `NodeMap::load_from_vec`.
    Saved,
    /// Learned from a ping received over UDP.
    Udp,
    /// Learned from traffic (a ping or relayed packet) received via a relay server.
    Relay,
    /// Presumably supplied directly by the application; not constructed in this file.
    App,
    /// Presumably found by the discovery service identified by `name` — TODO confirm.
    #[strum(serialize = "{name}")]
    Discovery {
        name: String,
    },
    /// Supplied by an application component identified by `name` (the tests use `"test"`).
    #[strum(serialize = "{name}")]
    NamedApp {
        name: String,
    },
}
impl NodeMap {
pub(super) fn load_from_vec(nodes: Vec<NodeAddr>) -> Self {
Self::from_inner(NodeMapInner::load_from_vec(nodes))
}
fn from_inner(inner: NodeMapInner) -> Self {
Self {
inner: Mutex::new(inner),
}
}
pub(super) fn add_node_addr(&self, node_addr: NodeAddr, source: Source) {
self.inner.lock().add_node_addr(node_addr, source)
}
pub(super) fn node_count(&self) -> usize {
self.inner.lock().node_count()
}
pub(super) fn receive_udp(&self, udp_addr: SocketAddr) -> Option<(PublicKey, QuicMappedAddr)> {
self.inner.lock().receive_udp(udp_addr)
}
pub(super) fn receive_relay(&self, relay_url: &RelayUrl, src: NodeId) -> QuicMappedAddr {
self.inner.lock().receive_relay(relay_url, src)
}
pub(super) fn notify_ping_sent(
&self,
id: usize,
dst: SendAddr,
tx_id: stun::TransactionId,
purpose: DiscoPingPurpose,
msg_sender: tokio::sync::mpsc::Sender<ActorMessage>,
) {
if let Some(ep) = self.inner.lock().get_mut(NodeStateKey::Idx(id)) {
ep.ping_sent(dst, tx_id, purpose, msg_sender);
}
}
pub(super) fn notify_ping_timeout(&self, id: usize, tx_id: stun::TransactionId) {
if let Some(ep) = self.inner.lock().get_mut(NodeStateKey::Idx(id)) {
ep.ping_timeout(tx_id);
}
}
pub(super) fn get_quic_mapped_addr_for_node_key(
&self,
node_key: NodeId,
) -> Option<QuicMappedAddr> {
self.inner
.lock()
.get(NodeStateKey::NodeId(node_key))
.map(|ep| *ep.quic_mapped_addr())
}
pub(super) fn handle_ping(
&self,
sender: PublicKey,
src: SendAddr,
tx_id: TransactionId,
) -> PingHandled {
self.inner.lock().handle_ping(sender, src, tx_id)
}
pub(super) fn handle_pong(&self, sender: PublicKey, src: &DiscoMessageSource, pong: Pong) {
self.inner.lock().handle_pong(sender, src, pong)
}
#[must_use = "actions must be handled"]
pub(super) fn handle_call_me_maybe(
&self,
sender: PublicKey,
cm: CallMeMaybe,
) -> Vec<PingAction> {
self.inner.lock().handle_call_me_maybe(sender, cm)
}
#[allow(clippy::type_complexity)]
pub(super) fn get_send_addrs(
&self,
addr: QuicMappedAddr,
have_ipv6: bool,
) -> Option<(
PublicKey,
Option<SocketAddr>,
Option<RelayUrl>,
Vec<PingAction>,
)> {
let mut inner = self.inner.lock();
let ep = inner.get_mut(NodeStateKey::QuicMappedAddr(addr))?;
let public_key = *ep.public_key();
trace!(dest = %addr, node_id = %public_key.fmt_short(), "dst mapped to NodeId");
let (udp_addr, relay_url, msgs) = ep.get_send_addrs(have_ipv6);
Some((public_key, udp_addr, relay_url, msgs))
}
pub(super) fn notify_shutdown(&self) {
let mut inner = self.inner.lock();
for (_, ep) in inner.node_states_mut() {
ep.reset();
}
}
pub(super) fn reset_node_states(&self) {
let mut inner = self.inner.lock();
for (_, ep) in inner.node_states_mut() {
ep.note_connectivity_change();
}
}
pub(super) fn nodes_stayin_alive(&self) -> Vec<PingAction> {
let mut inner = self.inner.lock();
inner
.node_states_mut()
.flat_map(|(_idx, node_state)| node_state.stayin_alive())
.collect()
}
pub(super) fn list_remote_infos(&self, now: Instant) -> Vec<RemoteInfo> {
self.inner.lock().remote_infos_iter(now).collect()
}
pub(super) fn conn_type_stream(&self, node_id: NodeId) -> anyhow::Result<ConnectionTypeStream> {
self.inner.lock().conn_type_stream(node_id)
}
pub(super) fn remote_info(&self, node_id: NodeId) -> Option<RemoteInfo> {
self.inner.lock().remote_info(node_id)
}
pub(super) fn prune_inactive(&self) {
self.inner.lock().prune_inactive();
}
pub(crate) fn on_direct_addr_discovered(&self, discovered: BTreeSet<SocketAddr>) {
self.inner.lock().on_direct_addr_discovered(discovered);
}
}
impl NodeMapInner {
    /// Builds a map pre-populated with `nodes`, each recorded as [`Source::Saved`].
    fn load_from_vec(nodes: Vec<NodeAddr>) -> Self {
        let mut me = Self::default();
        for node_addr in nodes {
            me.add_node_addr(node_addr, Source::Saved);
        }
        me
    }

    /// Merges `node_addr` into the map.
    ///
    /// An unknown node is inserted as inactive with the given `source`. The node state is then
    /// updated from the address info, and each direct address is mapped to the node.
    #[instrument(skip_all, fields(node = %node_addr.node_id.fmt_short()))]
    fn add_node_addr(&mut self, node_addr: NodeAddr, source: Source) {
        let NodeAddr { node_id, info } = node_addr;
        let source0 = source.clone();
        let node_state = self.get_or_insert_with(NodeStateKey::NodeId(node_id), || Options {
            node_id,
            relay_url: info.relay_url.clone(),
            active: false,
            source,
        });
        node_state.update_from_node_addr(&info, source0);
        let id = node_state.id();
        for addr in &info.direct_addresses {
            self.set_node_state_for_ip_port(*addr, id);
        }
    }

    /// Stops treating any of `discovered` as a remote node address.
    ///
    /// These addresses are cleared with `ClearReason::MatchesOurLocalAddr`.
    pub(super) fn on_direct_addr_discovered(&mut self, discovered: BTreeSet<SocketAddr>) {
        for addr in discovered {
            self.remove_by_ipp(addr.into(), ClearReason::MatchesOurLocalAddr)
        }
    }

    /// Unmaps `ipp` and asks the owning node to forget it as a direct address.
    ///
    /// If the node is left with no direct addresses, it is removed from the map entirely. In
    /// that case every remaining `by_ip_port` entry that still points at it is removed too.
    fn remove_by_ipp(&mut self, ipp: IpPort, reason: ClearReason) {
        let Some(id) = self.by_ip_port.remove(&ipp) else {
            return;
        };
        let Entry::Occupied(mut entry) = self.by_id.entry(id) else {
            return;
        };
        let node = entry.get_mut();
        node.remove_direct_addr(&ipp, reason);
        if node.direct_addresses().count() == 0 {
            let node_id = node.public_key();
            let mapped_addr = node.quic_mapped_addr();
            self.by_node_key.remove(node_id);
            self.by_quic_mapped_addr.remove(mapped_addr);
            debug!(node_id=%node_id.fmt_short(), ?reason, "removing node");
            entry.remove();
            // Other ip:port entries (e.g. for addresses the node has since pruned) may still
            // reference the removed index; drop them so no mapping dangles.
            self.by_ip_port.retain(|_, mapped_id| *mapped_id != id);
        }
    }

    /// Resolves any [`NodeStateKey`] to a node index.
    ///
    /// `Idx` is returned as-is without checking that a node exists for it.
    fn get_id(&self, id: NodeStateKey) -> Option<usize> {
        match id {
            NodeStateKey::Idx(id) => Some(id),
            NodeStateKey::NodeId(node_key) => self.by_node_key.get(&node_key).copied(),
            NodeStateKey::QuicMappedAddr(addr) => self.by_quic_mapped_addr.get(&addr).copied(),
            NodeStateKey::IpPort(ipp) => self.by_ip_port.get(&ipp).copied(),
        }
    }

    /// Mutable lookup of a node by any key.
    fn get_mut(&mut self, id: NodeStateKey) -> Option<&mut NodeState> {
        self.get_id(id).and_then(|id| self.by_id.get_mut(&id))
    }

    /// Shared lookup of a node by any key.
    fn get(&self, id: NodeStateKey) -> Option<&NodeState> {
        self.get_id(id).and_then(|id| self.by_id.get(&id))
    }

    /// Returns the node for `id`, inserting one built from `f()` if the key is unknown.
    ///
    /// # Panics
    ///
    /// Panics if the key resolves to an index that has no node state.
    fn get_or_insert_with(
        &mut self,
        id: NodeStateKey,
        f: impl FnOnce() -> Options,
    ) -> &mut NodeState {
        let id = self.get_id(id);
        match id {
            None => self.insert_node(f()),
            Some(id) => self.by_id.get_mut(&id).expect("is not empty"),
        }
    }

    /// Number of stored nodes.
    fn node_count(&self) -> usize {
        self.by_id.len()
    }

    /// Records a UDP datagram from `udp_addr` on the node mapped to that address.
    ///
    /// Returns the node's key and QUIC mapped address. Returns `None` (with an info log) when no
    /// node is mapped to the address.
    fn receive_udp(&mut self, udp_addr: SocketAddr) -> Option<(NodeId, QuicMappedAddr)> {
        let ip_port: IpPort = udp_addr.into();
        let Some(node_state) = self.get_mut(NodeStateKey::IpPort(ip_port)) else {
            info!(src=%udp_addr, "receive_udp: no node_state found for addr, ignore");
            return None;
        };
        node_state.receive_udp(ip_port, Instant::now());
        Some((*node_state.public_key(), *node_state.quic_mapped_addr()))
    }

    /// Records a relayed message from `src`, returning its QUIC mapped address.
    ///
    /// An unknown `src` is inserted as an active node with [`Source::Relay`].
    #[instrument(skip_all, fields(src = %src.fmt_short()))]
    fn receive_relay(&mut self, relay_url: &RelayUrl, src: NodeId) -> QuicMappedAddr {
        let node_state = self.get_or_insert_with(NodeStateKey::NodeId(src), || {
            trace!("packets from unknown node, insert into node map");
            Options {
                node_id: src,
                relay_url: Some(relay_url.clone()),
                active: true,
                source: Source::Relay,
            }
        });
        node_state.receive_relay(relay_url, src, Instant::now());
        *node_state.quic_mapped_addr()
    }

    /// Iterates over all `(index, node)` pairs.
    fn node_states(&self) -> impl Iterator<Item = (&usize, &NodeState)> {
        self.by_id.iter()
    }

    /// Iterates mutably over all `(index, node)` pairs.
    fn node_states_mut(&mut self) -> impl Iterator<Item = (&usize, &mut NodeState)> {
        self.by_id.iter_mut()
    }

    /// Lazily produces the [`RemoteInfo`] of every node as of `now`.
    fn remote_infos_iter(&self, now: Instant) -> impl Iterator<Item = RemoteInfo> + '_ {
        self.node_states().map(move |(_, ep)| ep.info(now))
    }

    /// Returns the current [`RemoteInfo`] for `node_id`, if known.
    fn remote_info(&self, node_id: NodeId) -> Option<RemoteInfo> {
        self.get(NodeStateKey::NodeId(node_id))
            .map(|ep| ep.info(Instant::now()))
    }

    /// Returns a stream that first yields the node's current connection type, then its changes.
    ///
    /// # Errors
    ///
    /// Fails if `node_id` is unknown.
    fn conn_type_stream(&self, node_id: NodeId) -> anyhow::Result<ConnectionTypeStream> {
        match self.get(NodeStateKey::NodeId(node_id)) {
            Some(ep) => Ok(ConnectionTypeStream {
                initial: Some(ep.conn_type()),
                inner: ep.conn_type_stream(),
            }),
            None => anyhow::bail!("No endpoint for {node_id:?} found"),
        }
    }

    /// Passes a pong to the sending node.
    ///
    /// If the node reports a new `(addr, key)` association, that address is mapped to the key.
    /// Pongs from unknown nodes are logged and ignored.
    fn handle_pong(&mut self, sender: NodeId, src: &DiscoMessageSource, pong: Pong) {
        let Some(ns) = self.get_mut(NodeStateKey::NodeId(sender)) else {
            warn!("received pong: node unknown, ignore");
            return;
        };
        let insert = ns.handle_pong(&pong, src.into());
        if let Some((src, key)) = insert {
            self.set_node_key_for_ip_port(src, &key);
        }
        trace!(?insert, "received pong");
    }

    /// Handles a call-me-maybe from `sender`.
    ///
    /// First maps each advertised address to the sender (if known), then lets the node decide
    /// which pings to send. Unknown senders bump the `recv_disco_call_me_maybe_bad_disco`
    /// metric and yield no actions.
    #[must_use = "actions must be handled"]
    fn handle_call_me_maybe(&mut self, sender: NodeId, cm: CallMeMaybe) -> Vec<PingAction> {
        let ns_id = NodeStateKey::NodeId(sender);
        if let Some(id) = self.get_id(ns_id.clone()) {
            for number in &cm.my_numbers {
                self.set_node_state_for_ip_port(*number, id);
            }
        }
        match self.get_mut(ns_id) {
            None => {
                inc!(MagicsockMetrics, recv_disco_call_me_maybe_bad_disco);
                debug!("received call-me-maybe: ignore, node is unknown");
                vec![]
            }
            Some(ns) => {
                debug!(endpoints = ?cm.my_numbers, "received call-me-maybe");
                ns.handle_call_me_maybe(cm)
            }
        }
    }

    /// Handles a disco ping from `sender` arriving via `src`.
    ///
    /// An unknown sender is inserted as an active node, sourced as [`Source::Relay`] or
    /// [`Source::Udp`] depending on `src`. When a UDP ping is classified as
    /// [`PingRole::NewPath`], the UDP address is mapped to the sender.
    fn handle_ping(&mut self, sender: NodeId, src: SendAddr, tx_id: TransactionId) -> PingHandled {
        let node_state = self.get_or_insert_with(NodeStateKey::NodeId(sender), || {
            debug!("received ping: node unknown, add to node map");
            let source = if src.is_relay() {
                Source::Relay
            } else {
                Source::Udp
            };
            Options {
                node_id: sender,
                relay_url: src.relay_url(),
                active: true,
                source,
            }
        });
        let handled = node_state.handle_ping(src.clone(), tx_id);
        if let SendAddr::Udp(ref addr) = src {
            if matches!(handled.role, PingRole::NewPath) {
                self.set_node_key_for_ip_port(*addr, &sender);
            }
        }
        handled
    }

    /// Inserts a new node built from `options` under a fresh index.
    ///
    /// The node's QUIC mapped address and public key are registered in the secondary indexes.
    /// Callers are expected to have checked that the key is not already present.
    fn insert_node(&mut self, options: Options) -> &mut NodeState {
        info!(
            node = %options.node_id.fmt_short(),
            relay_url = ?options.relay_url,
            source = %options.source,
            "inserting new node in NodeMap",
        );
        let id = self.next_id;
        self.next_id = self.next_id.wrapping_add(1);
        let node_state = NodeState::new(id, options);
        self.by_quic_mapped_addr
            .insert(*node_state.quic_mapped_addr(), id);
        self.by_node_key.insert(*node_state.public_key(), id);
        self.by_id.insert(id, node_state);
        self.by_id.get_mut(&id).expect("just inserted")
    }

    /// Points `ipp` at the node whose key is `nk`.
    ///
    /// If `nk` is unknown but `ipp` is already mapped, `nk` is first associated with that
    /// node's index. If `nk` is still unknown afterwards, nothing changes.
    fn set_node_key_for_ip_port(&mut self, ipp: impl Into<IpPort>, nk: &PublicKey) {
        let ipp = ipp.into();
        if let Some(&id) = self.by_ip_port.get(&ipp) {
            self.by_node_key.entry(*nk).or_insert(id);
        }
        if let Some(&id) = self.by_node_key.get(nk) {
            trace!("insert ip -> id: {:?} -> {}", ipp, id);
            self.by_ip_port.insert(ipp, id);
        }
    }

    /// Maps `ipp` to the node index `id`, replacing any previous mapping.
    fn set_node_state_for_ip_port(&mut self, ipp: impl Into<IpPort>, id: usize) {
        let ipp = ipp.into();
        trace!(?ipp, ?id, "set endpoint for ip:port");
        self.by_ip_port.insert(ipp, id);
    }

    /// Removes inactive nodes beyond [`MAX_INACTIVE_NODES`], least recently used first.
    ///
    /// A pruned node's ip:port mappings are removed by index rather than by the node's current
    /// direct addresses. That way a mapping which has since moved to another node is kept, and
    /// mappings for addresses the node no longer lists are not leaked.
    fn prune_inactive(&mut self) {
        let now = Instant::now();
        let mut prune_candidates: Vec<_> = self
            .by_id
            .values()
            .filter(|node| !node.is_active(&now))
            .map(|node| (*node.public_key(), node.last_used()))
            .collect();

        let prune_count = prune_candidates.len().saturating_sub(MAX_INACTIVE_NODES);
        if prune_count == 0 {
            return;
        }

        // `None` (never used) orders before every `Some`, so never-used nodes are pruned first.
        prune_candidates.sort_unstable_by_key(|(_pk, last_used)| *last_used);
        prune_candidates.truncate(prune_count);

        let mut pruned_ids = BTreeSet::new();
        for (public_key, last_used) in prune_candidates {
            let node = public_key.fmt_short();
            match last_used.map(|instant| instant.elapsed()) {
                Some(last_used) => trace!(%node, ?last_used, "pruning inactive"),
                None => trace!(%node, last_used=%"never", "pruning inactive"),
            }

            let Some(id) = self.by_node_key.remove(&public_key) else {
                debug_assert!(false, "missing by_node_key entry for pk in by_id");
                continue;
            };
            pruned_ids.insert(id);

            let Some(ep) = self.by_id.remove(&id) else {
                debug_assert!(false, "missing by_id entry for id in by_node_key");
                continue;
            };
            self.by_quic_mapped_addr.remove(ep.quic_mapped_addr());
        }

        self.by_ip_port
            .retain(|_, mapped_id| !pruned_ids.contains(&*mapped_id));
    }
}
/// Stream of a node's connection type.
///
/// The connection type captured when the stream was created is yielded first. After that,
/// whatever the underlying watcher stream produces is forwarded.
#[derive(Debug)]
pub struct ConnectionTypeStream {
    initial: Option<ConnectionType>,
    inner: watchable::WatcherStream<ConnectionType>,
}

impl Stream for ConnectionTypeStream {
    type Item = ConnectionType;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = self.get_mut();
        match stream.initial.take() {
            Some(conn_type) => Poll::Ready(Some(conn_type)),
            None => Pin::new(&mut stream.inner).poll_next(cx),
        }
    }
}
/// An IP address and port pair, usable as a hash-map key; displays like a [`SocketAddr`].
#[derive(Debug, derive_more::Display, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[display("{}", SocketAddr::from(*self))]
pub struct IpPort {
    ip: IpAddr,
    port: u16,
}

impl From<SocketAddr> for IpPort {
    fn from(addr: SocketAddr) -> Self {
        let (ip, port) = (addr.ip(), addr.port());
        Self { ip, port }
    }
}

impl From<IpPort> for SocketAddr {
    fn from(value: IpPort) -> Self {
        SocketAddr::new(value.ip, value.port)
    }
}

impl IpPort {
    /// The IP address.
    pub fn ip(&self) -> &IpAddr {
        &self.ip
    }

    /// The port number.
    pub fn port(&self) -> u16 {
        self.port
    }
}
#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use super::{node_state::MAX_INACTIVE_DIRECT_ADDRESSES, *};
    use crate::key::SecretKey;

    impl NodeMap {
        /// Test helper: adds `node_addr` with a `Source::NamedApp { name: "test" }` source.
        #[track_caller]
        fn add_test_addr(&self, node_addr: NodeAddr) {
            let source = Source::NamedApp {
                name: "test".into(),
            };
            self.add_node_addr(node_addr, source)
        }
    }

    /// Builds an IPv4 localhost socket address on `port`.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)
    }

    /// Converts every remote info in `node_map` into a [`NodeAddr`], skipping nodes that carry
    /// no addressing information.
    fn non_empty_node_addrs(node_map: &NodeMap) -> Vec<NodeAddr> {
        node_map
            .list_remote_infos(Instant::now())
            .into_iter()
            .filter_map(|info| {
                let node_addr: NodeAddr = info.into();
                (!node_addr.info.is_empty()).then_some(node_addr)
            })
            .collect()
    }

    /// Non-empty node addresses survive a round-trip through `NodeMap::load_from_vec`.
    #[tokio::test]
    async fn restore_from_vec() {
        let _guard = iroh_test::logging::setup();
        let node_map = NodeMap::default();

        let node_a = SecretKey::generate().public();
        let node_b = SecretKey::generate().public();
        let node_c = SecretKey::generate().public();
        let node_d = SecretKey::generate().public();
        let relay_x: RelayUrl = "https://my-relay-1.com".parse().unwrap();
        let relay_y: RelayUrl = "https://my-relay-2.com".parse().unwrap();

        // a: relay + two direct addresses, b: relay only, c: direct only, d: nothing at all.
        let node_addrs = [
            NodeAddr::new(node_a)
                .with_relay_url(relay_x)
                .with_direct_addresses([addr(4000), addr(4001)]),
            NodeAddr::new(node_b).with_relay_url(relay_y),
            NodeAddr::new(node_c).with_direct_addresses([addr(5000)]),
            NodeAddr::new(node_d),
        ];
        for node_addr in node_addrs {
            node_map.add_test_addr(node_addr);
        }

        let mut addrs = non_empty_node_addrs(&node_map);
        let loaded_node_map = NodeMap::load_from_vec(addrs.clone());
        let mut loaded = non_empty_node_addrs(&loaded_node_map);

        loaded.sort_unstable();
        addrs.sort_unstable();
        assert_eq!(addrs, loaded);
    }

    /// Pruning a node's direct addresses keeps the expected numbers of total and inactive ones.
    #[test]
    fn test_prune_direct_addresses() {
        let _guard = iroh_test::logging::setup();
        let node_map = NodeMap::default();
        let public_key = SecretKey::generate().public();

        let options = Options {
            node_id: public_key,
            relay_url: None,
            active: false,
            source: Source::NamedApp {
                name: "test".into(),
            },
        };
        let id = node_map.inner.lock().insert_node(options).id();

        info!("Adding active addresses");
        for i in 0..MAX_INACTIVE_DIRECT_ADDRESSES {
            let active = addr(5000 + i as u16);
            node_map.add_test_addr(NodeAddr::new(public_key).with_direct_addresses([active]));
            node_map.inner.lock().receive_udp(active);
        }

        info!("Adding offline/inactive addresses");
        for i in 0..MAX_INACTIVE_DIRECT_ADDRESSES * 2 {
            let inactive = addr(6000 + i as u16);
            node_map.add_test_addr(NodeAddr::new(public_key).with_direct_addresses([inactive]));
        }

        let mut inner = node_map.inner.lock();
        let endpoint = inner.by_id.get_mut(&id).unwrap();

        info!("Adding alive addresses");
        for i in 0..MAX_INACTIVE_DIRECT_ADDRESSES {
            let alive = SendAddr::Udp(addr(7000 + i as u16));
            let txid = stun::TransactionId::from([i as u8; 12]);
            endpoint.handle_ping(alive, txid);
        }

        info!("Pruning addresses");
        endpoint.prune_direct_addresses();

        assert_eq!(
            endpoint.direct_addresses().count(),
            MAX_INACTIVE_DIRECT_ADDRESSES * 3
        );
        let inactive_count = endpoint
            .direct_address_states()
            .filter(|(_addr, state)| !state.is_active())
            .count();
        assert_eq!(inactive_count, MAX_INACTIVE_DIRECT_ADDRESSES * 2)
    }

    /// One inactive node over the limit is pruned, and the node that received traffic is kept.
    #[test]
    fn test_prune_inactive() {
        let node_map = NodeMap::default();

        // A node that has received UDP traffic; it is expected not to be pruned.
        let active_node = SecretKey::generate().public();
        let active_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 167);
        node_map.add_test_addr(NodeAddr::new(active_node).with_direct_addresses([active_addr]));
        node_map
            .inner
            .lock()
            .receive_udp(active_addr)
            .expect("registered");

        // One more inactive node than MAX_INACTIVE_NODES allows.
        (0..=MAX_INACTIVE_NODES).for_each(|_| {
            node_map.add_test_addr(NodeAddr::new(SecretKey::generate().public()));
        });
        assert_eq!(node_map.node_count(), MAX_INACTIVE_NODES + 2);

        node_map.prune_inactive();
        assert_eq!(node_map.node_count(), MAX_INACTIVE_NODES + 1);

        node_map
            .inner
            .lock()
            .get(NodeStateKey::NodeId(active_node))
            .expect("should not be pruned");
    }
}