1use std::{
24 collections::{HashMap, HashSet, VecDeque},
25 convert::Infallible,
26 num::NonZeroUsize,
27 task::{Context, Poll},
28};
29
30use either::Either;
31use libp2p_core::{
32 connection::ConnectedPoint, multiaddr::Protocol, transport::PortUse, Endpoint, Multiaddr,
33};
34use libp2p_identity::PeerId;
35use libp2p_swarm::{
36 behaviour::{ConnectionClosed, DialFailure, FromSwarm},
37 dial_opts::{self, DialOpts},
38 dummy, ConnectionDenied, ConnectionHandler, ConnectionId, NetworkBehaviour,
39 NewExternalAddrCandidate, NotifyHandler, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
40};
41use lru::LruCache;
42use thiserror::Error;
43
44use crate::{handler, protocol};
45
/// Upper bound on listener-role direct dial attempts tracked per (relayed connection, peer)
/// pair; once a failed attempt brings the count to this value, `AttemptsExceeded` is reported.
pub(crate) const MAX_NUMBER_OF_UPGRADE_ATTEMPTS: u8 = 3;
47
48#[derive(Debug)]
50pub struct Event {
51 pub remote_peer_id: PeerId,
52 pub result: Result<ConnectionId, Error>,
53}
54
55#[derive(Debug, Error)]
56#[error("Failed to hole-punch connection: {inner}")]
57pub struct Error {
58 inner: InnerError,
59}
60
61#[derive(Debug, Error)]
62enum InnerError {
63 #[error("Giving up after {0} dial attempts")]
64 AttemptsExceeded(u8),
65 #[error("Inbound stream error: {0}")]
66 InboundError(protocol::inbound::Error),
67 #[error("Outbound stream error: {0}")]
68 OutboundError(protocol::outbound::Error),
69}
70
71pub struct Behaviour {
72 queued_events: VecDeque<ToSwarm<Event, Either<handler::relayed::Command, Infallible>>>,
74
75 direct_connections: HashMap<PeerId, HashSet<ConnectionId>>,
77
78 address_candidates: Candidates,
79
80 direct_to_relayed_connections: HashMap<ConnectionId, ConnectionId>,
81
82 outgoing_direct_connection_attempts: HashMap<(ConnectionId, PeerId), u8>,
85}
86
87impl Behaviour {
88 pub fn new(local_peer_id: PeerId) -> Self {
89 Behaviour {
90 queued_events: Default::default(),
91 direct_connections: Default::default(),
92 address_candidates: Candidates::new(local_peer_id),
93 direct_to_relayed_connections: Default::default(),
94 outgoing_direct_connection_attempts: Default::default(),
95 }
96 }
97
98 fn observed_addresses(&self) -> Vec<Multiaddr> {
99 self.address_candidates.iter().cloned().collect()
100 }
101
102 fn on_dial_failure(
103 &mut self,
104 DialFailure {
105 peer_id,
106 connection_id: failed_direct_connection,
107 ..
108 }: DialFailure,
109 ) {
110 let Some(peer_id) = peer_id else {
111 return;
112 };
113
114 let Some(relayed_connection_id) = self
115 .direct_to_relayed_connections
116 .get(&failed_direct_connection)
117 else {
118 return;
119 };
120
121 let Some(attempt) = self
122 .outgoing_direct_connection_attempts
123 .get(&(*relayed_connection_id, peer_id))
124 else {
125 return;
126 };
127
128 if *attempt < MAX_NUMBER_OF_UPGRADE_ATTEMPTS {
129 self.queued_events.push_back(ToSwarm::NotifyHandler {
130 handler: NotifyHandler::One(*relayed_connection_id),
131 peer_id,
132 event: Either::Left(handler::relayed::Command::Connect),
133 })
134 } else {
135 self.queued_events.extend([ToSwarm::GenerateEvent(Event {
136 remote_peer_id: peer_id,
137 result: Err(Error {
138 inner: InnerError::AttemptsExceeded(MAX_NUMBER_OF_UPGRADE_ATTEMPTS),
139 }),
140 })]);
141 }
142 }
143
144 fn on_connection_closed(
145 &mut self,
146 ConnectionClosed {
147 peer_id,
148 connection_id,
149 endpoint: connected_point,
150 ..
151 }: ConnectionClosed,
152 ) {
153 if !connected_point.is_relayed() {
154 let connections = self
155 .direct_connections
156 .get_mut(&peer_id)
157 .expect("Peer of direct connection to be tracked.");
158 connections
159 .remove(&connection_id)
160 .then_some(())
161 .expect("Direct connection to be tracked.");
162 if connections.is_empty() {
163 self.direct_connections.remove(&peer_id);
164 }
165 }
166 }
167}
168
169impl NetworkBehaviour for Behaviour {
170 type ConnectionHandler = Either<handler::relayed::Handler, dummy::ConnectionHandler>;
171 type ToSwarm = Event;
172
173 fn handle_established_inbound_connection(
174 &mut self,
175 connection_id: ConnectionId,
176 peer: PeerId,
177 local_addr: &Multiaddr,
178 remote_addr: &Multiaddr,
179 ) -> Result<THandler<Self>, ConnectionDenied> {
180 if is_relayed(local_addr) {
181 let connected_point = ConnectedPoint::Listener {
182 local_addr: local_addr.clone(),
183 send_back_addr: remote_addr.clone(),
184 };
185 let mut handler =
186 handler::relayed::Handler::new(connected_point, self.observed_addresses());
187 handler.on_behaviour_event(handler::relayed::Command::Connect);
188
189 return Ok(Either::Left(handler));
191 }
192 self.direct_connections
193 .entry(peer)
194 .or_default()
195 .insert(connection_id);
196
197 assert!(
198 !self
199 .direct_to_relayed_connections
200 .contains_key(&connection_id),
201 "state mismatch"
202 );
203
204 Ok(Either::Right(dummy::ConnectionHandler))
205 }
206
207 fn handle_established_outbound_connection(
208 &mut self,
209 connection_id: ConnectionId,
210 peer: PeerId,
211 addr: &Multiaddr,
212 role_override: Endpoint,
213 port_use: PortUse,
214 ) -> Result<THandler<Self>, ConnectionDenied> {
215 if is_relayed(addr) {
216 return Ok(Either::Left(handler::relayed::Handler::new(
217 ConnectedPoint::Dialer {
218 address: addr.clone(),
219 role_override,
220 port_use,
221 },
222 self.observed_addresses(),
223 ))); }
226
227 self.direct_connections
228 .entry(peer)
229 .or_default()
230 .insert(connection_id);
231
232 if let Some(&relayed_connection_id) = self.direct_to_relayed_connections.get(&connection_id)
234 {
235 if role_override == Endpoint::Listener {
236 assert!(
237 self.outgoing_direct_connection_attempts
238 .remove(&(relayed_connection_id, peer))
239 .is_some(),
240 "state mismatch"
241 );
242 }
243
244 self.queued_events.extend([ToSwarm::GenerateEvent(Event {
245 remote_peer_id: peer,
246 result: Ok(connection_id),
247 })]);
248 }
249 Ok(Either::Right(dummy::ConnectionHandler))
250 }
251
252 fn on_connection_handler_event(
253 &mut self,
254 event_source: PeerId,
255 connection_id: ConnectionId,
256 handler_event: THandlerOutEvent<Self>,
257 ) {
258 let relayed_connection_id = match handler_event.as_ref() {
259 Either::Left(_) => connection_id,
260 Either::Right(_) => match self.direct_to_relayed_connections.get(&connection_id) {
261 None => {
262 return;
265 }
266 Some(relayed_connection_id) => *relayed_connection_id,
267 },
268 };
269
270 match handler_event {
271 Either::Left(handler::relayed::Event::InboundConnectNegotiated { remote_addrs }) => {
272 tracing::debug!(target=%event_source, addresses=?remote_addrs, "Attempting to hole-punch as dialer");
273
274 let opts = DialOpts::peer_id(event_source)
275 .addresses(remote_addrs)
276 .condition(dial_opts::PeerCondition::Always)
277 .build();
278
279 let maybe_direct_connection_id = opts.connection_id();
280
281 self.direct_to_relayed_connections
282 .insert(maybe_direct_connection_id, relayed_connection_id);
283 self.queued_events.push_back(ToSwarm::Dial { opts });
284 }
285 Either::Left(handler::relayed::Event::InboundConnectFailed { error }) => {
286 self.queued_events.push_back(ToSwarm::GenerateEvent(Event {
287 remote_peer_id: event_source,
288 result: Err(Error {
289 inner: InnerError::InboundError(error),
290 }),
291 }));
292 }
293 Either::Left(handler::relayed::Event::OutboundConnectFailed { error }) => {
294 self.queued_events.push_back(ToSwarm::GenerateEvent(Event {
295 remote_peer_id: event_source,
296 result: Err(Error {
297 inner: InnerError::OutboundError(error),
298 }),
299 }));
300
301 }
303 Either::Left(handler::relayed::Event::OutboundConnectNegotiated { remote_addrs }) => {
304 tracing::debug!(target=%event_source, addresses=?remote_addrs, "Attempting to hole-punch as listener");
305
306 let opts = DialOpts::peer_id(event_source)
307 .condition(dial_opts::PeerCondition::Always)
308 .addresses(remote_addrs)
309 .override_role()
310 .build();
311
312 let maybe_direct_connection_id = opts.connection_id();
313
314 self.direct_to_relayed_connections
315 .insert(maybe_direct_connection_id, relayed_connection_id);
316 *self
317 .outgoing_direct_connection_attempts
318 .entry((relayed_connection_id, event_source))
319 .or_default() += 1;
320 self.queued_events.push_back(ToSwarm::Dial { opts });
321 }
322 #[allow(unreachable_patterns)]
324 Either::Right(never) => libp2p_core::util::unreachable(never),
325 };
326 }
327
328 #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self))]
329 fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
330 if let Some(event) = self.queued_events.pop_front() {
331 return Poll::Ready(event);
332 }
333
334 Poll::Pending
335 }
336
337 fn on_swarm_event(&mut self, event: FromSwarm) {
338 match event {
339 FromSwarm::ConnectionClosed(connection_closed) => {
340 self.on_connection_closed(connection_closed)
341 }
342 FromSwarm::DialFailure(dial_failure) => self.on_dial_failure(dial_failure),
343 FromSwarm::NewExternalAddrCandidate(NewExternalAddrCandidate { addr }) => {
344 self.address_candidates.add(addr.clone());
345 }
346 _ => {}
347 }
348 }
349}
350
351struct Candidates {
359 inner: LruCache<Multiaddr, ()>,
360 me: PeerId,
361}
362
363impl Candidates {
364 fn new(me: PeerId) -> Self {
365 Self {
366 inner: LruCache::new(NonZeroUsize::new(20).expect("20 > 0")),
367 me,
368 }
369 }
370
371 fn add(&mut self, mut address: Multiaddr) {
372 if is_relayed(&address) {
373 return;
374 }
375
376 if address.iter().last() != Some(Protocol::P2p(self.me)) {
377 address.push(Protocol::P2p(self.me));
378 }
379
380 self.inner.push(address, ());
381 }
382
383 fn iter(&self) -> impl Iterator<Item = &Multiaddr> {
384 self.inner.iter().map(|(a, _)| a)
385 }
386}
387
388fn is_relayed(addr: &Multiaddr) -> bool {
389 addr.iter().any(|p| p == Protocol::P2pCircuit)
390}