1use std::{
21 collections::{HashMap, HashSet, VecDeque},
22 num::NonZeroU8,
23};
24
25use libp2p_core::{multiaddr::Protocol, Multiaddr};
26use libp2p_identity::PeerId;
27use libp2p_request_response::{
28 self as request_response, InboundFailure, InboundRequestId, ResponseChannel,
29};
30use libp2p_swarm::{
31 dial_opts::{DialOpts, PeerCondition},
32 ConnectionId, DialError, ToSwarm,
33};
34use web_time::Instant;
35
36use super::{
37 Action, AutoNatCodec, Config, DialRequest, DialResponse, Event, HandleInnerEvent, ProbeId,
38 ResponseError,
39};
40
#[derive(Debug)]
/// Reason why an inbound probe (a dial-back requested by a remote peer) failed.
pub enum InboundProbeError {
    /// The inbound request itself failed at the request-response layer.
    InboundRequest(InboundFailure),
    /// The probe was refused or the dial-back failed, described by the corresponding
    /// [`ResponseError`].
    Response(ResponseError),
}
49
#[derive(Debug)]
/// Events emitted while serving inbound AutoNAT probes, i.e. dial-back requests
/// received from remote peers.
pub enum InboundProbeEvent {
    /// A dial request was accepted and a dial-back to the listed addresses was started.
    Request {
        /// Id assigned to this probe.
        probe_id: ProbeId,
        /// The peer that requested the dial-back.
        peer: PeerId,
        /// The (filtered) addresses that are dialed back.
        addresses: Vec<Multiaddr>,
    },
    /// A dial-back succeeded and a success response carrying the reached address was
    /// sent to the requester (best effort).
    Response {
        /// Id of the probe that completed.
        probe_id: ProbeId,
        /// The peer that requested the dial-back.
        peer: PeerId,
        /// The address at which the dial-back connection was established.
        address: Multiaddr,
    },
    /// The probe failed: the request was refused or invalid, the dial-back failed, or
    /// the inbound request itself failed.
    Error {
        /// Id of the failed probe (freshly allocated if no ongoing probe matched).
        probe_id: ProbeId,
        /// The peer that sent the request.
        peer: PeerId,
        /// What went wrong.
        error: InboundProbeError,
    },
}
76
77pub(crate) struct AsServer<'a> {
79 pub(crate) inner: &'a mut request_response::Behaviour<AutoNatCodec>,
80 pub(crate) config: &'a Config,
81 pub(crate) connected: &'a HashMap<PeerId, HashMap<ConnectionId, Option<Multiaddr>>>,
82 pub(crate) probe_id: &'a mut ProbeId,
83 pub(crate) throttled_clients: &'a mut Vec<(PeerId, Instant)>,
84 #[allow(clippy::type_complexity)]
85 pub(crate) ongoing_inbound: &'a mut HashMap<
86 PeerId,
87 (
88 ProbeId,
89 InboundRequestId,
90 Vec<Multiaddr>,
91 ResponseChannel<DialResponse>,
92 ),
93 >,
94}
95
96impl HandleInnerEvent for AsServer<'_> {
97 fn handle_event(
98 &mut self,
99 event: request_response::Event<DialRequest, DialResponse>,
100 ) -> VecDeque<Action> {
101 match event {
102 request_response::Event::Message {
103 peer,
104 message:
105 request_response::Message::Request {
106 request_id,
107 request,
108 channel,
109 },
110 ..
111 } => {
112 let probe_id = self.probe_id.next();
113 if !self.connected.contains_key(&peer) {
114 tracing::debug!(
115 %peer,
116 "Reject inbound dial request from peer since it is not connected"
117 );
118
119 return VecDeque::from([ToSwarm::GenerateEvent(Event::InboundProbe(
120 InboundProbeEvent::Error {
121 probe_id,
122 peer,
123 error: InboundProbeError::Response(ResponseError::DialRefused),
124 },
125 ))]);
126 }
127
128 match self.resolve_inbound_request(peer, request) {
129 Ok(addrs) => {
130 tracing::debug!(
131 %peer,
132 "Inbound dial request from peer with dial-back addresses {:?}",
133 addrs
134 );
135
136 self.ongoing_inbound
137 .insert(peer, (probe_id, request_id, addrs.clone(), channel));
138 self.throttled_clients.push((peer, Instant::now()));
139
140 VecDeque::from([
141 ToSwarm::GenerateEvent(Event::InboundProbe(
142 InboundProbeEvent::Request {
143 probe_id,
144 peer,
145 addresses: addrs.clone(),
146 },
147 )),
148 ToSwarm::Dial {
149 opts: DialOpts::peer_id(peer)
150 .condition(PeerCondition::Always)
151 .override_dial_concurrency_factor(
152 NonZeroU8::new(1).expect("1 > 0"),
153 )
154 .addresses(addrs)
155 .allocate_new_port()
156 .build(),
157 },
158 ])
159 }
160 Err((status_text, error)) => {
161 tracing::debug!(
162 %peer,
163 status=%status_text,
164 "Reject inbound dial request from peer"
165 );
166
167 let response = DialResponse {
168 result: Err(error.clone()),
169 status_text: Some(status_text),
170 };
171 let _ = self.inner.send_response(channel, response);
172
173 VecDeque::from([ToSwarm::GenerateEvent(Event::InboundProbe(
174 InboundProbeEvent::Error {
175 probe_id,
176 peer,
177 error: InboundProbeError::Response(error),
178 },
179 ))])
180 }
181 }
182 }
183 request_response::Event::InboundFailure {
184 peer,
185 error,
186 request_id,
187 ..
188 } => {
189 tracing::debug!(
190 %peer,
191 "Inbound Failure {} when on dial-back request from peer",
192 error
193 );
194
195 let probe_id = match self.ongoing_inbound.get(&peer) {
196 Some((_, rq_id, _, _)) if *rq_id == request_id => {
197 self.ongoing_inbound.remove(&peer).unwrap().0
198 }
199 _ => self.probe_id.next(),
200 };
201
202 VecDeque::from([ToSwarm::GenerateEvent(Event::InboundProbe(
203 InboundProbeEvent::Error {
204 probe_id,
205 peer,
206 error: InboundProbeError::InboundRequest(error),
207 },
208 ))])
209 }
210 _ => VecDeque::new(),
211 }
212 }
213}
214
215impl AsServer<'_> {
216 pub(crate) fn on_outbound_connection(
217 &mut self,
218 peer: &PeerId,
219 address: &Multiaddr,
220 ) -> Option<InboundProbeEvent> {
221 let (_, _, addrs, _) = self.ongoing_inbound.get(peer)?;
222
223 if !addrs.contains(address) {
225 return None;
226 }
227
228 tracing::debug!(
229 %peer,
230 %address,
231 "Dial-back to peer succeeded"
232 );
233
234 let (probe_id, _, _, channel) = self.ongoing_inbound.remove(peer).unwrap();
235 let response = DialResponse {
236 result: Ok(address.clone()),
237 status_text: None,
238 };
239 let _ = self.inner.send_response(channel, response);
240
241 Some(InboundProbeEvent::Response {
242 probe_id,
243 peer: *peer,
244 address: address.clone(),
245 })
246 }
247
248 pub(crate) fn on_outbound_dial_error(
249 &mut self,
250 peer: Option<PeerId>,
251 error: &DialError,
252 ) -> Option<InboundProbeEvent> {
253 let (probe_id, _, _, channel) = peer.and_then(|p| self.ongoing_inbound.remove(&p))?;
254
255 match peer {
256 Some(p) => tracing::debug!(
257 peer=%p,
258 "Dial-back to peer failed with error {:?}",
259 error
260 ),
261 None => tracing::debug!(
262 "Dial-back to non existent peer failed with error {:?}",
263 error
264 ),
265 };
266
267 let response_error = ResponseError::DialError;
268 let response = DialResponse {
269 result: Err(response_error.clone()),
270 status_text: Some("dial failed".to_string()),
271 };
272 let _ = self.inner.send_response(channel, response);
273
274 Some(InboundProbeEvent::Error {
275 probe_id,
276 peer: peer.expect("PeerId is present."),
277 error: InboundProbeError::Response(response_error),
278 })
279 }
280
281 fn resolve_inbound_request(
283 &mut self,
284 sender: PeerId,
285 request: DialRequest,
286 ) -> Result<Vec<Multiaddr>, (String, ResponseError)> {
287 let i = self.throttled_clients.partition_point(|(_, time)| {
289 *time + self.config.throttle_clients_period < Instant::now()
290 });
291 self.throttled_clients.drain(..i);
292
293 if request.peer_id != sender {
294 let status_text = "peer id mismatch".to_string();
295 return Err((status_text, ResponseError::BadRequest));
296 }
297
298 if self.ongoing_inbound.contains_key(&sender) {
299 let status_text = "dial-back already ongoing".to_string();
300 return Err((status_text, ResponseError::DialRefused));
301 }
302
303 if self.throttled_clients.len() >= self.config.throttle_clients_global_max {
304 let status_text = "too many total dials".to_string();
305 return Err((status_text, ResponseError::DialRefused));
306 }
307
308 let throttled_for_client = self
309 .throttled_clients
310 .iter()
311 .filter(|(p, _)| p == &sender)
312 .count();
313
314 if throttled_for_client >= self.config.throttle_clients_peer_max {
315 let status_text = "too many dials for peer".to_string();
316 return Err((status_text, ResponseError::DialRefused));
317 }
318
319 let observed_addr = self
321 .connected
322 .get(&sender)
323 .expect("Peer is connected.")
324 .values()
325 .find_map(|a| a.as_ref())
326 .ok_or_else(|| {
327 let status_text = "refusing to dial peer with blocked observed address".to_string();
328 (status_text, ResponseError::DialRefused)
329 })?;
330
331 let mut addrs = Self::filter_valid_addrs(sender, request.addresses, observed_addr);
332 addrs.truncate(self.config.max_peer_addresses);
333
334 if addrs.is_empty() {
335 let status_text = "no dialable addresses".to_string();
336 return Err((status_text, ResponseError::DialRefused));
337 }
338
339 Ok(addrs)
340 }
341
342 fn filter_valid_addrs(
344 peer: PeerId,
345 demanded: Vec<Multiaddr>,
346 observed_remote_at: &Multiaddr,
347 ) -> Vec<Multiaddr> {
348 let Some(observed_ip) = observed_remote_at
349 .into_iter()
350 .find(|p| matches!(p, Protocol::Ip4(_) | Protocol::Ip6(_)))
351 else {
352 return Vec::new();
353 };
354
355 let mut distinct = HashSet::new();
356 demanded
357 .into_iter()
358 .filter_map(|addr| {
359 let i = addr
361 .iter()
362 .position(|p| matches!(p, Protocol::Ip4(_) | Protocol::Ip6(_)))?;
363 let mut addr = addr.replace(i, |_| Some(observed_ip.clone()))?;
364
365 let is_valid = addr.iter().all(|proto| match proto {
366 Protocol::P2pCircuit => false,
367 Protocol::P2p(peer_id) => peer_id == peer,
368 _ => true,
369 });
370
371 if !is_valid {
372 return None;
373 }
374 if !addr.iter().any(|p| matches!(p, Protocol::P2p(_))) {
375 addr.push(Protocol::P2p(peer))
376 }
377 distinct.insert(addr.clone()).then_some(addr)
379 })
380 .collect()
381 }
382}
383
#[cfg(test)]
mod test {
    use std::net::Ipv4Addr;

    use super::*;

    fn random_ip<'a>() -> Protocol<'a> {
        Protocol::Ip4(Ipv4Addr::new(
            rand::random(),
            rand::random(),
            rand::random(),
            rand::random(),
        ))
    }

    /// Builds `/ip4/<random>/tcp/<port>`.
    ///
    /// Ports are passed in explicitly instead of being randomised. Two addresses that only
    /// differ in their (rewritten) IP would otherwise collide whenever the random ports
    /// happen to match, and be deduplicated, which made the test flaky.
    fn addr_with_port(port: u16) -> Multiaddr {
        Multiaddr::empty().with(random_ip()).with(Protocol::Tcp(port))
    }

    #[test]
    fn filter_addresses() {
        let peer_id = PeerId::random();
        let observed_ip = random_ip();
        let observed_addr = Multiaddr::empty()
            .with(observed_ip.clone())
            .with(Protocol::Tcp(4001))
            .with(Protocol::P2p(peer_id));
        // Kept: already names the requesting peer.
        let demanded_1 = addr_with_port(1001).with(Protocol::P2p(peer_id));
        // Dropped: names a different peer.
        let demanded_2 = addr_with_port(1002).with(Protocol::P2p(PeerId::random()));
        // Kept: no peer id, so the requesting peer's id gets appended.
        let demanded_3 = addr_with_port(1003);
        // Dropped: relayed address.
        let demanded_4 = addr_with_port(1004)
            .with(Protocol::P2p(PeerId::random()))
            .with(Protocol::P2pCircuit)
            .with(Protocol::P2p(peer_id));
        let demanded = vec![
            demanded_1.clone(),
            demanded_2,
            demanded_3.clone(),
            demanded_4,
        ];

        let filtered = AsServer::filter_valid_addrs(peer_id, demanded, &observed_addr);

        let expected_1 = demanded_1
            .replace(0, |_| Some(observed_ip.clone()))
            .unwrap();
        let expected_2 = demanded_3
            .replace(0, |_| Some(observed_ip))
            .unwrap()
            .with(Protocol::P2p(peer_id));
        assert_eq!(filtered, vec![expected_1, expected_2]);
    }

    #[test]
    fn filter_addresses_deduplicates_after_ip_rewrite() {
        let peer_id = PeerId::random();
        let observed_addr = addr_with_port(4001);
        // Both collapse to `/<observed ip>/tcp/1001/p2p/<peer>` once the IP is rewritten.
        let demanded = vec![addr_with_port(1001), addr_with_port(1001)];

        let filtered = AsServer::filter_valid_addrs(peer_id, demanded, &observed_addr);

        assert_eq!(filtered.len(), 1);
    }

    #[test]
    fn filter_addresses_requires_ip_components() {
        let peer_id = PeerId::random();

        // A requested address without any IP component cannot be rewritten and is dropped.
        let observed_addr = addr_with_port(4001);
        let demanded = vec![Multiaddr::empty()
            .with(Protocol::Dns4("example.com".into()))
            .with(Protocol::Tcp(1001))];
        assert!(AsServer::filter_valid_addrs(peer_id, demanded, &observed_addr).is_empty());

        // Without an observed IP nothing can be dialed back.
        let observed_without_ip = Multiaddr::empty().with(Protocol::Tcp(4001));
        let demanded = vec![addr_with_port(1001)];
        assert!(AsServer::filter_valid_addrs(peer_id, demanded, &observed_without_ip).is_empty());
    }
}