1mod iface;
22mod socket;
23mod timer;
24
25use std::{
26 cmp,
27 collections::{
28 hash_map::{Entry, HashMap},
29 VecDeque,
30 },
31 convert::Infallible,
32 fmt,
33 future::Future,
34 io,
35 net::IpAddr,
36 pin::Pin,
37 sync::{Arc, RwLock},
38 task::{Context, Poll},
39 time::Instant,
40};
41
42use futures::{channel::mpsc, Stream, StreamExt};
43use if_watch::IfEvent;
44use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
45use libp2p_identity::PeerId;
46use libp2p_swarm::{
47 behaviour::FromSwarm, dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour,
48 THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
49};
50use smallvec::SmallVec;
51
52use self::iface::InterfaceState;
53use crate::{
54 behaviour::{socket::AsyncSocket, timer::Builder},
55 Config,
56};
57
/// Abstraction over an async runtime, so the mDNS behaviour can run on different executors.
///
/// A provider bundles the runtime-specific pieces the behaviour needs:
/// - a socket type for the per-interface tasks;
/// - a timer type;
/// - a network-interface watcher;
/// - a way to spawn (and later abort) background tasks.
pub trait Provider: 'static {
    /// Socket type handed to each per-interface task (the bundled providers use UDP sockets).
    type Socket: AsyncSocket;
    /// Timer type. `Builder` provides `at(deadline)`. The behaviour polls the timer as a
    /// `Stream` to be woken at the earliest expiration of a discovered record.
    type Timer: Builder + Stream;
    /// Stream of network-interface up/down events.
    ///
    /// Must be `Unpin` because `poll` pins it in place with `Pin::new`. Must be `Debug` because
    /// `Behaviour` derives `Debug` and stores the watcher.
    type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;

    /// Handle to a spawned per-interface task; it is aborted when its interface goes down.
    type TaskHandle: Abort;

    /// Creates a new interface watcher.
    ///
    /// # Errors
    ///
    /// Any error is propagated out of `Behaviour::new`.
    fn new_watcher() -> Result<Self::Watcher, std::io::Error>;

    /// Spawns `task` on the runtime and returns a handle that can abort it.
    ///
    /// Marked `#[track_caller]`, presumably so runtimes that record spawn sites attribute the
    /// task to the caller.
    #[track_caller]
    fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
}
75
/// A handle to a spawned background task that can be cancelled.
// NOTE(review): `unreachable_pub` is presumably allowed because this trait must be `pub` (it
// bounds `Provider::TaskHandle`) while not being re-exported from the crate root; confirm.
#[allow(unreachable_pub)] pub trait Abort {
    /// Cancels the task, consuming the handle.
    fn abort(self);
}
80
/// Runtime provider built on `async-io` / `async-std`.
#[cfg(feature = "async-io")]
pub mod async_io {
    use std::future::Future;

    use async_std::task::JoinHandle;
    use if_watch::smol::IfWatcher;

    use super::Provider;
    use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort};

    /// Uninhabited marker type selecting the `async-io` based [`Provider`].
    #[doc(hidden)]
    pub enum AsyncIo {}

    impl Provider for AsyncIo {
        type Socket = AsyncUdpSocket;
        type Timer = AsyncTimer;
        type Watcher = IfWatcher;
        type TaskHandle = JoinHandle<()>;

        fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
            IfWatcher::new()
        }

        fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
            async_std::task::spawn(task)
        }
    }

    impl Abort for JoinHandle<()> {
        fn abort(self) {
            // `cancel` consumes the handle and yields a future. It is driven on its own task so
            // that `abort` stays synchronous. The returned handle is dropped, which in async-std
            // detaches the task rather than cancelling it.
            let cancellation = self.cancel();
            async_std::task::spawn(cancellation);
        }
    }

    /// The mDNS behaviour running on `async-io` / `async-std`.
    pub type Behaviour = super::Behaviour<AsyncIo>;
}
118
/// Runtime provider built on Tokio.
#[cfg(feature = "tokio")]
pub mod tokio {
    use std::future::Future;

    use if_watch::tokio::IfWatcher;
    use tokio::task::JoinHandle;

    use super::Provider;
    use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort};

    /// Uninhabited marker type selecting the Tokio-based [`Provider`].
    #[doc(hidden)]
    pub enum Tokio {}

    impl Provider for Tokio {
        type Socket = TokioUdpSocket;
        type Timer = TokioTimer;
        type Watcher = IfWatcher;
        type TaskHandle = JoinHandle<()>;

        fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
            IfWatcher::new()
        }

        fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
            tokio::task::spawn(task)
        }
    }

    impl Abort for JoinHandle<()> {
        fn abort(self) {
            // Call Tokio's inherent `JoinHandle::abort(&self)` by path. Method syntax
            // (`self.abort()`) would pick this by-value trait method first and recurse forever.
            JoinHandle::abort(&self);
        }
    }

    /// The mDNS behaviour running on Tokio.
    pub type Behaviour = super::Behaviour<Tokio>;
}
156
/// A `NetworkBehaviour` for mDNS peer discovery on the local network.
///
/// It runs one background task per eligible network interface and collects the
/// `(peer, address)` records those tasks report. It emits [`Event`]s and hands discovered
/// addresses to the swarm.
#[derive(Debug)]
pub struct Behaviour<P>
where
    P: Provider,
{
    /// Behaviour configuration; a clone is given to every interface task.
    config: Config,

    /// Watches for network interfaces going up or down.
    if_watch: P::Watcher,

    /// One spawned task per interface address; the handle is aborted when the interface goes down.
    if_tasks: HashMap<IpAddr, P::TaskHandle>,

    /// Receives `(peer, address, expiration)` records reported by the interface tasks.
    query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
    /// Sending half of the channel above, cloned into every interface task.
    query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,

    /// Known `(peer, address)` pairs with their expiration instant.
    ///
    /// A peer appears once per address. Entries past their expiration are only removed (and
    /// reported as expired) during the next `poll`.
    discovered_nodes: SmallVec<[(PeerId, Multiaddr, Instant); 8]>,

    /// Timer for the earliest expiration in `discovered_nodes`, used to get `poll` woken up
    /// when a record expires.
    closest_expiration: Option<P::Timer>,

    /// The swarm's current listen addresses.
    ///
    /// Updated in `on_swarm_event` and shared (via the `Arc`) with every interface task.
    listen_addresses: Arc<RwLock<ListenAddresses>>,

    /// Id of the local peer, passed to every interface task.
    local_peer_id: PeerId,

    /// Events queued for emission by `poll`.
    pending_events: VecDeque<ToSwarm<Event, Infallible>>,
}
199
200impl<P> Behaviour<P>
201where
202 P: Provider,
203{
204 pub fn new(config: Config, local_peer_id: PeerId) -> io::Result<Self> {
206 let (tx, rx) = mpsc::channel(10); Ok(Self {
209 config,
210 if_watch: P::new_watcher()?,
211 if_tasks: Default::default(),
212 query_response_receiver: rx,
213 query_response_sender: tx,
214 discovered_nodes: Default::default(),
215 closest_expiration: Default::default(),
216 listen_addresses: Default::default(),
217 local_peer_id,
218 pending_events: Default::default(),
219 })
220 }
221
222 #[deprecated(note = "Use `discovered_nodes` iterator instead.")]
224 pub fn has_node(&self, peer_id: &PeerId) -> bool {
225 self.discovered_nodes().any(|p| p == peer_id)
226 }
227
228 pub fn discovered_nodes(&self) -> impl ExactSizeIterator<Item = &PeerId> {
230 self.discovered_nodes.iter().map(|(p, _, _)| p)
231 }
232
233 #[deprecated(note = "Unused API. Will be removed in the next release.")]
235 pub fn expire_node(&mut self, peer_id: &PeerId) {
236 let now = Instant::now();
237 for (peer, _addr, expires) in &mut self.discovered_nodes {
238 if peer == peer_id {
239 *expires = now;
240 }
241 }
242 self.closest_expiration = Some(P::Timer::at(now));
243 }
244}
245
impl<P> NetworkBehaviour for Behaviour<P>
where
    P: Provider,
{
    // Discovery happens in the per-interface tasks, not over swarm connections, so every
    // connection is given a no-op `dummy::ConnectionHandler`.
    type ConnectionHandler = dummy::ConnectionHandler;
    type ToSwarm = Event;

    /// Accepts every established inbound connection without inspecting it.
    fn handle_established_inbound_connection(
        &mut self,
        _: ConnectionId,
        _: PeerId,
        _: &Multiaddr,
        _: &Multiaddr,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        Ok(dummy::ConnectionHandler)
    }

    /// Supplies the mDNS-discovered addresses of the peer being dialed.
    ///
    /// Returns an empty list when the dial carries no peer id. Otherwise returns every address
    /// currently recorded for that peer in `discovered_nodes`. This includes records that have
    /// expired but have not yet been pruned by `poll`. The addresses the swarm already has
    /// (`_addresses`) and the role are ignored.
    fn handle_pending_outbound_connection(
        &mut self,
        _connection_id: ConnectionId,
        maybe_peer: Option<PeerId>,
        _addresses: &[Multiaddr],
        _effective_role: Endpoint,
    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
        let peer_id = match maybe_peer {
            None => return Ok(vec![]),
            Some(peer) => peer,
        };

        Ok(self
            .discovered_nodes
            .iter()
            .filter(|(peer, _, _)| peer == &peer_id)
            .map(|(_, addr, _)| addr.clone())
            .collect())
    }

    /// Accepts every established outbound connection without inspecting it.
    fn handle_established_outbound_connection(
        &mut self,
        _: ConnectionId,
        _: PeerId,
        _: &Multiaddr,
        _: Endpoint,
        _: PortUse,
    ) -> Result<THandler<Self>, ConnectionDenied> {
        Ok(dummy::ConnectionHandler)
    }

    /// Cannot be reached with a real event. `ev` has the dummy handler's output type, which
    /// `libp2p_core::util::unreachable` consumes (presumably an uninhabited type).
    fn on_connection_handler_event(
        &mut self,
        _: PeerId,
        _: ConnectionId,
        ev: THandlerOutEvent<Self>,
    ) {
        libp2p_core::util::unreachable(ev)
    }

    /// Forwards every swarm event to the shared `ListenAddresses` tracker.
    ///
    /// The tracker's `Arc` is cloned into each interface task. A poisoned lock is recovered via
    /// `into_inner` instead of panicking, so listen-address updates keep flowing.
    fn on_swarm_event(&mut self, event: FromSwarm) {
        self.listen_addresses
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .on_swarm_event(&event);
    }

    /// Drives the behaviour. Each pass of the loop:
    ///
    /// 1. Returns the next queued event, if any.
    /// 2. Drains interface-watcher events. It spawns one task per new non-loopback address of
    ///    the configured IP family, and aborts the task of an interface that goes down.
    /// 3. Drains records reported by the interface tasks. Each new `(peer, address)` pair is
    ///    stored and queued as `NewExternalAddrOfPeer`; all new pairs together form a single
    ///    `Event::Discovered`.
    /// 4. Prunes expired records into a single `Event::Expired`.
    /// 5. If nothing was queued, arms a timer for the earliest remaining expiration and
    ///    returns `Pending`.
    #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
        loop {
            // Emit queued events first; later steps `continue` the loop after queuing new ones.
            if let Some(event) = self.pending_events.pop_front() {
                return Poll::Ready(event);
            }

            // Drain interface changes reported by the watcher.
            while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) {
                match event {
                    Ok(IfEvent::Up(inet)) => {
                        let addr = inet.addr();
                        // Loopback addresses are skipped.
                        if addr.is_loopback() {
                            continue;
                        }
                        // Only the IP family selected by `enable_ipv6` is used:
                        // IPv6 when set, IPv4 otherwise.
                        if addr.is_ipv4() && self.config.enable_ipv6
                            || addr.is_ipv6() && !self.config.enable_ipv6
                        {
                            continue;
                        }
                        // At most one task per address; an already-running task is kept.
                        if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
                            match InterfaceState::<P::Socket, P::Timer>::new(
                                addr,
                                self.config.clone(),
                                self.local_peer_id,
                                self.listen_addresses.clone(),
                                self.query_response_sender.clone(),
                            ) {
                                Ok(iface_state) => {
                                    e.insert(P::spawn(iface_state));
                                }
                                // Best-effort: log and keep serving the other interfaces.
                                Err(err) => {
                                    tracing::error!("failed to create `InterfaceState`: {}", err)
                                }
                            }
                        }
                    }
                    // The interface went away: stop its task, if one was running.
                    Ok(IfEvent::Down(inet)) => {
                        if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
                            tracing::info!(instance=%inet.addr(), "dropping instance");

                            handle.abort();
                        }
                    }
                    // Watcher errors are logged and otherwise ignored.
                    Err(err) => tracing::error!("if watch returned an error: {}", err),
                }
            }
            // Collect pairs that were not known before, for a single `Discovered` event.
            let mut discovered = Vec::new();

            while let Poll::Ready(Some((peer, addr, expiration))) =
                self.query_response_receiver.poll_next_unpin(cx)
            {
                if let Some((_, _, cur_expires)) = self
                    .discovered_nodes
                    .iter_mut()
                    .find(|(p, a, _)| *p == peer && *a == addr)
                {
                    // Already known: keep the later expiration; it is not reported again.
                    *cur_expires = cmp::max(*cur_expires, expiration);
                } else {
                    tracing::info!(%peer, address=%addr, "discovered peer on address");
                    self.discovered_nodes.push((peer, addr.clone(), expiration));
                    discovered.push((peer, addr.clone()));

                    self.pending_events
                        .push_back(ToSwarm::NewExternalAddrOfPeer {
                            peer_id: peer,
                            address: addr,
                        });
                }
            }

            if !discovered.is_empty() {
                let event = Event::Discovered(discovered);
                // Push to the front so the aggregated `Discovered` event is emitted before the
                // per-address `NewExternalAddrOfPeer` events queued above.
                self.pending_events
                    .push_front(ToSwarm::GenerateEvent(event));
                continue;
            }
            // Prune expired records while tracking the earliest remaining expiration.
            let now = Instant::now();
            let mut closest_expiration = None;
            let mut expired = Vec::new();
            self.discovered_nodes.retain(|(peer, addr, expiration)| {
                if *expiration <= now {
                    tracing::info!(%peer, address=%addr, "expired peer on address");
                    expired.push((*peer, addr.clone()));
                    return false;
                }
                closest_expiration =
                    Some(closest_expiration.unwrap_or(*expiration).min(*expiration));
                true
            });
            if !expired.is_empty() {
                let event = Event::Expired(expired);
                self.pending_events.push_back(ToSwarm::GenerateEvent(event));
                continue;
            }
            if let Some(closest_expiration) = closest_expiration {
                // Poll the fresh timer once so it (presumably) registers `cx`'s waker. The task
                // is then woken at the earliest expiration and reports it on that `poll`. Any
                // previous timer is replaced.
                let mut timer = P::Timer::at(closest_expiration);
                let _ = Pin::new(&mut timer).poll_next(cx);

                self.closest_expiration = Some(timer);
            }

            return Poll::Pending;
        }
    }
}
424
/// Events emitted by the mDNS [`Behaviour`].
#[derive(Debug, Clone)]
pub enum Event {
    /// `(peer, address)` pairs learned since the last report that were not already known.
    ///
    /// A pair that was already known only has its expiration extended and is not reported again.
    Discovered(Vec<(PeerId, Multiaddr)>),

    /// `(peer, address)` pairs whose expiration has passed.
    ///
    /// These pairs have been removed from the discovered set.
    Expired(Vec<(PeerId, Multiaddr)>),
}