// libp2p_mdns/behaviour.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21mod iface;
22mod socket;
23mod timer;
24
25use std::{
26    cmp,
27    collections::{
28        hash_map::{Entry, HashMap},
29        VecDeque,
30    },
31    convert::Infallible,
32    fmt,
33    future::Future,
34    io,
35    net::IpAddr,
36    pin::Pin,
37    sync::{Arc, RwLock},
38    task::{Context, Poll},
39    time::Instant,
40};
41
42use futures::{channel::mpsc, Stream, StreamExt};
43use if_watch::IfEvent;
44use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
45use libp2p_identity::PeerId;
46use libp2p_swarm::{
47    behaviour::FromSwarm, dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour,
48    THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
49};
50use smallvec::SmallVec;
51
52use self::iface::InterfaceState;
53use crate::{
54    behaviour::{socket::AsyncSocket, timer::Builder},
55    Config,
56};
57
58/// An abstraction to allow for compatibility with various async runtimes.
59pub trait Provider: 'static {
60    /// The Async Socket type.
61    type Socket: AsyncSocket;
62    /// The Async Timer type.
63    type Timer: Builder + Stream;
64    /// The IfWatcher type.
65    type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;
66
67    type TaskHandle: Abort;
68
69    /// Create a new instance of the `IfWatcher` type.
70    fn new_watcher() -> Result<Self::Watcher, std::io::Error>;
71
72    #[track_caller]
73    fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
74}
75
/// A handle to a spawned task that can be cancelled by consuming the handle.
#[allow(unreachable_pub)] // Not re-exported.
pub trait Abort {
    /// Requests cancellation of the task behind this handle.
    fn abort(self);
}
80
/// Runtime glue and the [`Behaviour`] type alias for the `async-io` implementation.
#[cfg(feature = "async-io")]
pub mod async_io {
    use std::future::Future;

    use async_std::task::JoinHandle;
    use if_watch::smol::IfWatcher;

    use super::Provider;
    use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort};

    /// Marker type selecting the `async-io` / `async-std` runtime. It has no
    /// values and is only used as a type parameter.
    #[doc(hidden)]
    pub enum AsyncIo {}

    impl Abort for JoinHandle<()> {
        fn abort(self) {
            // `cancel` returns a future, so it is driven on a detached task
            // instead of being awaited here.
            let cancellation = self.cancel();
            async_std::task::spawn(cancellation);
        }
    }

    impl Provider for AsyncIo {
        type Watcher = IfWatcher;
        type Socket = AsyncUdpSocket;
        type Timer = AsyncTimer;
        type TaskHandle = JoinHandle<()>;

        fn new_watcher() -> std::io::Result<Self::Watcher> {
            IfWatcher::new()
        }

        fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
            async_std::task::spawn(task)
        }
    }

    /// The mDNS behaviour running on `async-io`.
    pub type Behaviour = super::Behaviour<AsyncIo>;
}
118
/// Runtime glue and the [`Behaviour`] type alias for the `tokio` implementation.
#[cfg(feature = "tokio")]
pub mod tokio {
    use std::future::Future;

    use if_watch::tokio::IfWatcher;
    use tokio::task::JoinHandle;

    use super::Provider;
    use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort};

    /// Marker type selecting the `tokio` runtime. It has no values and is only
    /// used as a type parameter.
    #[doc(hidden)]
    pub enum Tokio {}

    impl Abort for JoinHandle<()> {
        fn abort(self) {
            // Resolves to tokio's inherent `JoinHandle::abort(&self)`, not to this
            // by-value trait method, so there is no recursion.
            JoinHandle::abort(&self)
        }
    }

    impl Provider for Tokio {
        type Watcher = IfWatcher;
        type Socket = TokioUdpSocket;
        type Timer = TokioTimer;
        type TaskHandle = JoinHandle<()>;

        fn new_watcher() -> std::io::Result<Self::Watcher> {
            IfWatcher::new()
        }

        fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
            tokio::spawn(task)
        }
    }

    /// The mDNS behaviour running on `tokio`.
    pub type Behaviour = super::Behaviour<Tokio>;
}
156
157/// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds
158/// them to the topology.
159#[derive(Debug)]
160pub struct Behaviour<P>
161where
162    P: Provider,
163{
164    /// InterfaceState config.
165    config: Config,
166
167    /// Iface watcher.
168    if_watch: P::Watcher,
169
170    /// Handles to tasks running the mDNS queries.
171    if_tasks: HashMap<IpAddr, P::TaskHandle>,
172
173    query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
174    query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
175
176    /// List of nodes that we have discovered, the address, and when their TTL expires.
177    ///
178    /// Each combination of `PeerId` and `Multiaddr` can only appear once, but the same `PeerId`
179    /// can appear multiple times.
180    discovered_nodes: SmallVec<[(PeerId, Multiaddr, Instant); 8]>,
181
182    /// Future that fires when the TTL of at least one node in `discovered_nodes` expires.
183    ///
184    /// `None` if `discovered_nodes` is empty.
185    closest_expiration: Option<P::Timer>,
186
187    /// The current set of listen addresses.
188    ///
189    /// This is shared across all interface tasks using an [`RwLock`].
190    /// The [`Behaviour`] updates this upon new [`FromSwarm`]
191    /// events where as [`InterfaceState`]s read from it to answer inbound mDNS queries.
192    listen_addresses: Arc<RwLock<ListenAddresses>>,
193
194    local_peer_id: PeerId,
195
196    /// Pending behaviour events to be emitted.
197    pending_events: VecDeque<ToSwarm<Event, Infallible>>,
198}
199
200impl<P> Behaviour<P>
201where
202    P: Provider,
203{
204    /// Builds a new `Mdns` behaviour.
205    pub fn new(config: Config, local_peer_id: PeerId) -> io::Result<Self> {
206        let (tx, rx) = mpsc::channel(10); // Chosen arbitrarily.
207
208        Ok(Self {
209            config,
210            if_watch: P::new_watcher()?,
211            if_tasks: Default::default(),
212            query_response_receiver: rx,
213            query_response_sender: tx,
214            discovered_nodes: Default::default(),
215            closest_expiration: Default::default(),
216            listen_addresses: Default::default(),
217            local_peer_id,
218            pending_events: Default::default(),
219        })
220    }
221
222    /// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
223    #[deprecated(note = "Use `discovered_nodes` iterator instead.")]
224    pub fn has_node(&self, peer_id: &PeerId) -> bool {
225        self.discovered_nodes().any(|p| p == peer_id)
226    }
227
228    /// Returns the list of nodes that we have discovered through mDNS and that are not expired.
229    pub fn discovered_nodes(&self) -> impl ExactSizeIterator<Item = &PeerId> {
230        self.discovered_nodes.iter().map(|(p, _, _)| p)
231    }
232
233    /// Expires a node before the ttl.
234    #[deprecated(note = "Unused API. Will be removed in the next release.")]
235    pub fn expire_node(&mut self, peer_id: &PeerId) {
236        let now = Instant::now();
237        for (peer, _addr, expires) in &mut self.discovered_nodes {
238            if peer == peer_id {
239                *expires = now;
240            }
241        }
242        self.closest_expiration = Some(P::Timer::at(now));
243    }
244}
245
246impl<P> NetworkBehaviour for Behaviour<P>
247where
248    P: Provider,
249{
250    type ConnectionHandler = dummy::ConnectionHandler;
251    type ToSwarm = Event;
252
253    fn handle_established_inbound_connection(
254        &mut self,
255        _: ConnectionId,
256        _: PeerId,
257        _: &Multiaddr,
258        _: &Multiaddr,
259    ) -> Result<THandler<Self>, ConnectionDenied> {
260        Ok(dummy::ConnectionHandler)
261    }
262
263    fn handle_pending_outbound_connection(
264        &mut self,
265        _connection_id: ConnectionId,
266        maybe_peer: Option<PeerId>,
267        _addresses: &[Multiaddr],
268        _effective_role: Endpoint,
269    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
270        let peer_id = match maybe_peer {
271            None => return Ok(vec![]),
272            Some(peer) => peer,
273        };
274
275        Ok(self
276            .discovered_nodes
277            .iter()
278            .filter(|(peer, _, _)| peer == &peer_id)
279            .map(|(_, addr, _)| addr.clone())
280            .collect())
281    }
282
283    fn handle_established_outbound_connection(
284        &mut self,
285        _: ConnectionId,
286        _: PeerId,
287        _: &Multiaddr,
288        _: Endpoint,
289        _: PortUse,
290    ) -> Result<THandler<Self>, ConnectionDenied> {
291        Ok(dummy::ConnectionHandler)
292    }
293
294    fn on_connection_handler_event(
295        &mut self,
296        _: PeerId,
297        _: ConnectionId,
298        ev: THandlerOutEvent<Self>,
299    ) {
300        libp2p_core::util::unreachable(ev)
301    }
302
303    fn on_swarm_event(&mut self, event: FromSwarm) {
304        self.listen_addresses
305            .write()
306            .unwrap_or_else(|e| e.into_inner())
307            .on_swarm_event(&event);
308    }
309
310    #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
311    fn poll(
312        &mut self,
313        cx: &mut Context<'_>,
314    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
315        loop {
316            // Check for pending events and emit them.
317            if let Some(event) = self.pending_events.pop_front() {
318                return Poll::Ready(event);
319            }
320
321            // Poll ifwatch.
322            while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) {
323                match event {
324                    Ok(IfEvent::Up(inet)) => {
325                        let addr = inet.addr();
326                        if addr.is_loopback() {
327                            continue;
328                        }
329                        if addr.is_ipv4() && self.config.enable_ipv6
330                            || addr.is_ipv6() && !self.config.enable_ipv6
331                        {
332                            continue;
333                        }
334                        if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
335                            match InterfaceState::<P::Socket, P::Timer>::new(
336                                addr,
337                                self.config.clone(),
338                                self.local_peer_id,
339                                self.listen_addresses.clone(),
340                                self.query_response_sender.clone(),
341                            ) {
342                                Ok(iface_state) => {
343                                    e.insert(P::spawn(iface_state));
344                                }
345                                Err(err) => {
346                                    tracing::error!("failed to create `InterfaceState`: {}", err)
347                                }
348                            }
349                        }
350                    }
351                    Ok(IfEvent::Down(inet)) => {
352                        if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
353                            tracing::info!(instance=%inet.addr(), "dropping instance");
354
355                            handle.abort();
356                        }
357                    }
358                    Err(err) => tracing::error!("if watch returned an error: {}", err),
359                }
360            }
361            // Emit discovered event.
362            let mut discovered = Vec::new();
363
364            while let Poll::Ready(Some((peer, addr, expiration))) =
365                self.query_response_receiver.poll_next_unpin(cx)
366            {
367                if let Some((_, _, cur_expires)) = self
368                    .discovered_nodes
369                    .iter_mut()
370                    .find(|(p, a, _)| *p == peer && *a == addr)
371                {
372                    *cur_expires = cmp::max(*cur_expires, expiration);
373                } else {
374                    tracing::info!(%peer, address=%addr, "discovered peer on address");
375                    self.discovered_nodes.push((peer, addr.clone(), expiration));
376                    discovered.push((peer, addr.clone()));
377
378                    self.pending_events
379                        .push_back(ToSwarm::NewExternalAddrOfPeer {
380                            peer_id: peer,
381                            address: addr,
382                        });
383                }
384            }
385
386            if !discovered.is_empty() {
387                let event = Event::Discovered(discovered);
388                // Push to the front of the queue so that the behavior event is reported before
389                // the individual discovered addresses.
390                self.pending_events
391                    .push_front(ToSwarm::GenerateEvent(event));
392                continue;
393            }
394            // Emit expired event.
395            let now = Instant::now();
396            let mut closest_expiration = None;
397            let mut expired = Vec::new();
398            self.discovered_nodes.retain(|(peer, addr, expiration)| {
399                if *expiration <= now {
400                    tracing::info!(%peer, address=%addr, "expired peer on address");
401                    expired.push((*peer, addr.clone()));
402                    return false;
403                }
404                closest_expiration =
405                    Some(closest_expiration.unwrap_or(*expiration).min(*expiration));
406                true
407            });
408            if !expired.is_empty() {
409                let event = Event::Expired(expired);
410                self.pending_events.push_back(ToSwarm::GenerateEvent(event));
411                continue;
412            }
413            if let Some(closest_expiration) = closest_expiration {
414                let mut timer = P::Timer::at(closest_expiration);
415                let _ = Pin::new(&mut timer).poll_next(cx);
416
417                self.closest_expiration = Some(timer);
418            }
419
420            return Poll::Pending;
421        }
422    }
423}
424
425/// Event that can be produced by the `Mdns` behaviour.
426#[derive(Debug, Clone)]
427pub enum Event {
428    /// Discovered nodes through mDNS.
429    Discovered(Vec<(PeerId, Multiaddr)>),
430
431    /// The given combinations of `PeerId` and `Multiaddr` have expired.
432    ///
433    /// Each discovered record has a time-to-live. When this TTL expires and the address hasn't
434    /// been refreshed, we remove it from the list and emit it as an `Expired` event.
435    Expired(Vec<(PeerId, Multiaddr)>),
436}