libp2p_quic/
transport.rs

1// Copyright 2017-2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    collections::{
23        hash_map::{DefaultHasher, Entry},
24        HashMap, HashSet,
25    },
26    fmt,
27    hash::{Hash, Hasher},
28    io,
29    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
30    pin::Pin,
31    task::{Context, Poll, Waker},
32    time::Duration,
33};
34
35use futures::{
36    channel::oneshot,
37    future::{BoxFuture, Either},
38    prelude::*,
39    ready,
40    stream::{SelectAll, StreamExt},
41};
42use if_watch::IfEvent;
43use libp2p_core::{
44    multiaddr::{Multiaddr, Protocol},
45    transport::{DialOpts, ListenerId, PortUse, TransportError, TransportEvent},
46    Endpoint, Transport,
47};
48use libp2p_identity::PeerId;
49use socket2::{Domain, Socket, Type};
50
51use crate::{
52    config::{Config, QuinnConfig},
53    hole_punching::hole_puncher,
54    provider::Provider,
55    ConnectError, Connecting, Connection, Error,
56};
57
58/// Implementation of the [`Transport`] trait for QUIC.
59///
60/// By default only QUIC Version 1 (RFC 9000) is supported. In the [`Multiaddr`] this maps to
61/// [`libp2p_core::multiaddr::Protocol::QuicV1`].
62/// The [`libp2p_core::multiaddr::Protocol::Quic`] codepoint is interpreted as QUIC version
63/// draft-29 and only supported if [`Config::support_draft_29`] is set to `true`.
64/// Note that in that case servers support both version an all QUIC listening addresses.
65///
66/// Version draft-29 should only be used to connect to nodes from other libp2p implementations
67/// that do not support `QuicV1` yet. Support for it will be removed long-term.
68/// See <https://github.com/multiformats/multiaddr/issues/145>.
69#[derive(Debug)]
70pub struct GenTransport<P: Provider> {
71    /// Config for the inner [`quinn`] structs.
72    quinn_config: QuinnConfig,
73    /// Timeout for the [`Connecting`] future.
74    handshake_timeout: Duration,
75    /// Whether draft-29 is supported for dialing and listening.
76    support_draft_29: bool,
77    /// Streams of active [`Listener`]s.
78    listeners: SelectAll<Listener<P>>,
79    /// Dialer for each socket family if no matching listener exists.
80    dialer: HashMap<SocketFamily, quinn::Endpoint>,
81    /// Waker to poll the transport again when a new dialer or listener is added.
82    waker: Option<Waker>,
83    /// Holepunching attempts
84    hole_punch_attempts: HashMap<SocketAddr, oneshot::Sender<Connecting>>,
85}
86
87impl<P: Provider> GenTransport<P> {
88    /// Create a new [`GenTransport`] with the given [`Config`].
89    pub fn new(config: Config) -> Self {
90        let handshake_timeout = config.handshake_timeout;
91        let support_draft_29 = config.support_draft_29;
92        let quinn_config = config.into();
93        Self {
94            listeners: SelectAll::new(),
95            quinn_config,
96            handshake_timeout,
97            dialer: HashMap::new(),
98            waker: None,
99            support_draft_29,
100            hole_punch_attempts: Default::default(),
101        }
102    }
103
104    /// Create a new [`quinn::Endpoint`] with the given configs.
105    fn new_endpoint(
106        endpoint_config: quinn::EndpointConfig,
107        server_config: Option<quinn::ServerConfig>,
108        socket: UdpSocket,
109    ) -> Result<quinn::Endpoint, Error> {
110        use crate::provider::Runtime;
111        match P::runtime() {
112            #[cfg(feature = "tokio")]
113            Runtime::Tokio => {
114                let runtime = std::sync::Arc::new(quinn::TokioRuntime);
115                let endpoint =
116                    quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
117                Ok(endpoint)
118            }
119            #[cfg(feature = "async-std")]
120            Runtime::AsyncStd => {
121                let runtime = std::sync::Arc::new(quinn::AsyncStdRuntime);
122                let endpoint =
123                    quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
124                Ok(endpoint)
125            }
126            Runtime::Dummy => {
127                let _ = endpoint_config;
128                let _ = server_config;
129                let _ = socket;
130                let err = std::io::Error::new(std::io::ErrorKind::Other, "no async runtime found");
131                Err(Error::Io(err))
132            }
133        }
134    }
135
136    /// Extract the addr, quic version and peer id from the given [`Multiaddr`].
137    fn remote_multiaddr_to_socketaddr(
138        &self,
139        addr: Multiaddr,
140        check_unspecified_addr: bool,
141    ) -> Result<
142        (SocketAddr, ProtocolVersion, Option<PeerId>),
143        TransportError<<Self as Transport>::Error>,
144    > {
145        let (socket_addr, version, peer_id) = multiaddr_to_socketaddr(&addr, self.support_draft_29)
146            .ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?;
147        if check_unspecified_addr && (socket_addr.port() == 0 || socket_addr.ip().is_unspecified())
148        {
149            return Err(TransportError::MultiaddrNotSupported(addr));
150        }
151        Ok((socket_addr, version, peer_id))
152    }
153
154    /// Pick any listener to use for dialing.
155    fn eligible_listener(&mut self, socket_addr: &SocketAddr) -> Option<&mut Listener<P>> {
156        let mut listeners: Vec<_> = self
157            .listeners
158            .iter_mut()
159            .filter(|l| {
160                if l.is_closed {
161                    return false;
162                }
163                SocketFamily::is_same(&l.socket_addr().ip(), &socket_addr.ip())
164            })
165            .filter(|l| {
166                if socket_addr.ip().is_loopback() {
167                    l.listening_addresses
168                        .iter()
169                        .any(|ip_addr| ip_addr.is_loopback())
170                } else {
171                    true
172                }
173            })
174            .collect();
175        match listeners.len() {
176            0 => None,
177            1 => listeners.pop(),
178            _ => {
179                // Pick any listener to use for dialing.
180                // We hash the socket address to achieve determinism.
181                let mut hasher = DefaultHasher::new();
182                socket_addr.hash(&mut hasher);
183                let index = hasher.finish() as usize % listeners.len();
184                Some(listeners.swap_remove(index))
185            }
186        }
187    }
188
189    fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<UdpSocket> {
190        let socket = Socket::new(
191            Domain::for_address(socket_addr),
192            Type::DGRAM,
193            Some(socket2::Protocol::UDP),
194        )?;
195        if socket_addr.is_ipv6() {
196            socket.set_only_v6(true)?;
197        }
198
199        socket.bind(&socket_addr.into())?;
200
201        Ok(socket.into())
202    }
203
204    fn bound_socket(&mut self, socket_addr: SocketAddr) -> Result<quinn::Endpoint, Error> {
205        let socket_family = socket_addr.ip().into();
206        if let Some(waker) = self.waker.take() {
207            waker.wake();
208        }
209        let listen_socket_addr = match socket_family {
210            SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
211            SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
212        };
213        let socket = UdpSocket::bind(listen_socket_addr)?;
214        let endpoint_config = self.quinn_config.endpoint_config.clone();
215        let endpoint = Self::new_endpoint(endpoint_config, None, socket)?;
216        Ok(endpoint)
217    }
218}
219
220impl<P: Provider> Transport for GenTransport<P> {
221    type Output = (PeerId, Connection);
222    type Error = Error;
223    type ListenerUpgrade = Connecting;
224    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
225
226    fn listen_on(
227        &mut self,
228        listener_id: ListenerId,
229        addr: Multiaddr,
230    ) -> Result<(), TransportError<Self::Error>> {
231        let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, false)?;
232        let endpoint_config = self.quinn_config.endpoint_config.clone();
233        let server_config = self.quinn_config.server_config.clone();
234        let socket = self.create_socket(socket_addr).map_err(Self::Error::from)?;
235
236        let socket_c = socket.try_clone().map_err(Self::Error::from)?;
237        let endpoint = Self::new_endpoint(endpoint_config, Some(server_config), socket)?;
238        let listener = Listener::new(
239            listener_id,
240            socket_c,
241            endpoint,
242            self.handshake_timeout,
243            version,
244        )?;
245        self.listeners.push(listener);
246
247        if let Some(waker) = self.waker.take() {
248            waker.wake();
249        }
250
251        // Remove dialer endpoint so that the endpoint is dropped once the last
252        // connection that uses it is closed.
253        // New outbound connections will use the bidirectional (listener) endpoint.
254        self.dialer.remove(&socket_addr.ip().into());
255
256        Ok(())
257    }
258
259    fn remove_listener(&mut self, id: ListenerId) -> bool {
260        if let Some(listener) = self.listeners.iter_mut().find(|l| l.listener_id == id) {
261            // Close the listener, which will eventually finish its stream.
262            // `SelectAll` removes streams once they are finished.
263            listener.close(Ok(()));
264            true
265        } else {
266            false
267        }
268    }
269
270    fn dial(
271        &mut self,
272        addr: Multiaddr,
273        dial_opts: DialOpts,
274    ) -> Result<Self::Dial, TransportError<Self::Error>> {
275        let (socket_addr, version, peer_id) =
276            self.remote_multiaddr_to_socketaddr(addr.clone(), true)?;
277
278        match (dial_opts.role, dial_opts.port_use) {
279            (Endpoint::Dialer, _) | (Endpoint::Listener, PortUse::Reuse) => {
280                let endpoint = if let Some(listener) = dial_opts
281                    .port_use
282                    .eq(&PortUse::Reuse)
283                    .then(|| self.eligible_listener(&socket_addr))
284                    .flatten()
285                {
286                    listener.endpoint.clone()
287                } else {
288                    let socket_family = socket_addr.ip().into();
289                    let dialer = if dial_opts.port_use == PortUse::Reuse {
290                        if let Some(occupied) = self.dialer.get(&socket_family) {
291                            occupied.clone()
292                        } else {
293                            let endpoint = self.bound_socket(socket_addr)?;
294                            self.dialer.insert(socket_family, endpoint.clone());
295                            endpoint
296                        }
297                    } else {
298                        self.bound_socket(socket_addr)?
299                    };
300                    dialer
301                };
302                let handshake_timeout = self.handshake_timeout;
303                let mut client_config = self.quinn_config.client_config.clone();
304                if version == ProtocolVersion::Draft29 {
305                    client_config.version(0xff00_001d);
306                }
307                Ok(Box::pin(async move {
308                    // This `"l"` seems necessary because an empty string is an invalid domain
309                    // name. While we don't use domain names, the underlying rustls library
310                    // is based upon the assumption that we do.
311                    let connecting = endpoint
312                        .connect_with(client_config, socket_addr, "l")
313                        .map_err(ConnectError)?;
314                    Connecting::new(connecting, handshake_timeout).await
315                }))
316            }
317            (Endpoint::Listener, _) => {
318                let peer_id = peer_id.ok_or(TransportError::MultiaddrNotSupported(addr.clone()))?;
319
320                let socket = self
321                    .eligible_listener(&socket_addr)
322                    .ok_or(TransportError::Other(
323                        Error::NoActiveListenerForDialAsListener,
324                    ))?
325                    .try_clone_socket()
326                    .map_err(Self::Error::from)?;
327
328                tracing::debug!("Preparing for hole-punch from {addr}");
329
330                let hole_puncher = hole_puncher::<P>(socket, socket_addr, self.handshake_timeout);
331
332                let (sender, receiver) = oneshot::channel();
333
334                match self.hole_punch_attempts.entry(socket_addr) {
335                    Entry::Occupied(mut sender_entry) => {
336                        // Stale senders, i.e. from failed hole punches are not removed.
337                        // Thus, we can just overwrite a stale sender.
338                        if !sender_entry.get().is_canceled() {
339                            return Err(TransportError::Other(Error::HolePunchInProgress(
340                                socket_addr,
341                            )));
342                        }
343                        sender_entry.insert(sender);
344                    }
345                    Entry::Vacant(entry) => {
346                        entry.insert(sender);
347                    }
348                };
349
350                Ok(Box::pin(async move {
351                    futures::pin_mut!(hole_puncher);
352                    match futures::future::select(receiver, hole_puncher).await {
353                        Either::Left((message, _)) => {
354                            let (inbound_peer_id, connection) = message
355                                .expect(
356                                    "hole punch connection sender is never dropped before receiver",
357                                )
358                                .await?;
359                            if inbound_peer_id != peer_id {
360                                tracing::warn!(
361                                    peer=%peer_id,
362                                    inbound_peer=%inbound_peer_id,
363                                    socket_address=%socket_addr,
364                                    "expected inbound connection from socket_address to resolve to peer but got inbound peer"
365                                );
366                            }
367                            Ok((inbound_peer_id, connection))
368                        }
369                        Either::Right((hole_punch_err, _)) => Err(hole_punch_err),
370                    }
371                }))
372            }
373        }
374    }
375
376    fn poll(
377        mut self: Pin<&mut Self>,
378        cx: &mut Context<'_>,
379    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
380        while let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
381            match ev {
382                TransportEvent::Incoming {
383                    listener_id,
384                    mut upgrade,
385                    local_addr,
386                    send_back_addr,
387                } => {
388                    let socket_addr =
389                        multiaddr_to_socketaddr(&send_back_addr, self.support_draft_29)
390                            .unwrap()
391                            .0;
392
393                    if let Some(sender) = self.hole_punch_attempts.remove(&socket_addr) {
394                        match sender.send(upgrade) {
395                            Ok(()) => continue,
396                            Err(timed_out_holepunch) => {
397                                upgrade = timed_out_holepunch;
398                            }
399                        }
400                    }
401
402                    return Poll::Ready(TransportEvent::Incoming {
403                        listener_id,
404                        upgrade,
405                        local_addr,
406                        send_back_addr,
407                    });
408                }
409                _ => return Poll::Ready(ev),
410            }
411        }
412
413        self.waker = Some(cx.waker().clone());
414        Poll::Pending
415    }
416}
417
418impl From<Error> for TransportError<Error> {
419    fn from(err: Error) -> Self {
420        TransportError::Other(err)
421    }
422}
423
424/// Listener for incoming connections.
425struct Listener<P: Provider> {
426    /// Id of the listener.
427    listener_id: ListenerId,
428
429    /// Version of the supported quic protocol.
430    version: ProtocolVersion,
431
432    /// Endpoint
433    endpoint: quinn::Endpoint,
434
435    /// An underlying copy of the socket to be able to hole punch with
436    socket: UdpSocket,
437
438    /// A future to poll new incoming connections.
439    accept: BoxFuture<'static, Option<quinn::Incoming>>,
440    /// Timeout for connection establishment on inbound connections.
441    handshake_timeout: Duration,
442
443    /// Watcher for network interface changes.
444    ///
445    /// None if we are only listening on a single interface.
446    if_watcher: Option<P::IfWatcher>,
447
448    /// Whether the listener was closed and the stream should terminate.
449    is_closed: bool,
450
451    /// Pending event to reported.
452    pending_event: Option<<Self as Stream>::Item>,
453
454    /// The stream must be awaken after it has been closed to deliver the last event.
455    close_listener_waker: Option<Waker>,
456
457    listening_addresses: HashSet<IpAddr>,
458}
459
460impl<P: Provider> Listener<P> {
461    fn new(
462        listener_id: ListenerId,
463        socket: UdpSocket,
464        endpoint: quinn::Endpoint,
465        handshake_timeout: Duration,
466        version: ProtocolVersion,
467    ) -> Result<Self, Error> {
468        let if_watcher;
469        let pending_event;
470        let mut listening_addresses = HashSet::new();
471        let local_addr = socket.local_addr()?;
472        if local_addr.ip().is_unspecified() {
473            if_watcher = Some(P::new_if_watcher()?);
474            pending_event = None;
475        } else {
476            if_watcher = None;
477            listening_addresses.insert(local_addr.ip());
478            let ma = socketaddr_to_multiaddr(&local_addr, version);
479            pending_event = Some(TransportEvent::NewAddress {
480                listener_id,
481                listen_addr: ma,
482            })
483        }
484
485        let endpoint_c = endpoint.clone();
486        let accept = async move { endpoint_c.accept().await }.boxed();
487
488        Ok(Listener {
489            endpoint,
490            socket,
491            accept,
492            listener_id,
493            version,
494            handshake_timeout,
495            if_watcher,
496            is_closed: false,
497            pending_event,
498            close_listener_waker: None,
499            listening_addresses,
500        })
501    }
502
503    /// Report the listener as closed in a [`TransportEvent::ListenerClosed`] and
504    /// terminate the stream.
505    fn close(&mut self, reason: Result<(), Error>) {
506        if self.is_closed {
507            return;
508        }
509        self.endpoint.close(From::from(0u32), &[]);
510        self.pending_event = Some(TransportEvent::ListenerClosed {
511            listener_id: self.listener_id,
512            reason,
513        });
514        self.is_closed = true;
515
516        // Wake the stream to deliver the last event.
517        if let Some(waker) = self.close_listener_waker.take() {
518            waker.wake();
519        }
520    }
521
522    /// Clone underlying socket (for hole punching).
523    fn try_clone_socket(&self) -> std::io::Result<UdpSocket> {
524        self.socket.try_clone()
525    }
526
527    fn socket_addr(&self) -> SocketAddr {
528        self.socket
529            .local_addr()
530            .expect("Cannot fail because the socket is bound")
531    }
532
533    /// Poll for a next If Event.
534    fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
535        let endpoint_addr = self.socket_addr();
536        let Some(if_watcher) = self.if_watcher.as_mut() else {
537            return Poll::Pending;
538        };
539        loop {
540            match ready!(P::poll_if_event(if_watcher, cx)) {
541                Ok(IfEvent::Up(inet)) => {
542                    if let Some(listen_addr) =
543                        ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
544                    {
545                        tracing::debug!(
546                            address=%listen_addr,
547                            "New listen address"
548                        );
549                        self.listening_addresses.insert(inet.addr());
550                        return Poll::Ready(TransportEvent::NewAddress {
551                            listener_id: self.listener_id,
552                            listen_addr,
553                        });
554                    }
555                }
556                Ok(IfEvent::Down(inet)) => {
557                    if let Some(listen_addr) =
558                        ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
559                    {
560                        tracing::debug!(
561                            address=%listen_addr,
562                            "Expired listen address"
563                        );
564                        self.listening_addresses.remove(&inet.addr());
565                        return Poll::Ready(TransportEvent::AddressExpired {
566                            listener_id: self.listener_id,
567                            listen_addr,
568                        });
569                    }
570                }
571                Err(err) => {
572                    return Poll::Ready(TransportEvent::ListenerError {
573                        listener_id: self.listener_id,
574                        error: err.into(),
575                    })
576                }
577            }
578        }
579    }
580}
581
582impl<P: Provider> Stream for Listener<P> {
583    type Item = TransportEvent<<GenTransport<P> as Transport>::ListenerUpgrade, Error>;
584    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
585        loop {
586            if let Some(event) = self.pending_event.take() {
587                return Poll::Ready(Some(event));
588            }
589            if self.is_closed {
590                return Poll::Ready(None);
591            }
592            if let Poll::Ready(event) = self.poll_if_addr(cx) {
593                return Poll::Ready(Some(event));
594            }
595
596            match self.accept.poll_unpin(cx) {
597                Poll::Ready(Some(incoming)) => {
598                    let endpoint = self.endpoint.clone();
599                    self.accept = async move { endpoint.accept().await }.boxed();
600
601                    let connecting = match incoming.accept() {
602                        Ok(connecting) => connecting,
603                        Err(error) => {
604                            return Poll::Ready(Some(TransportEvent::ListenerError {
605                                listener_id: self.listener_id,
606                                error: Error::Connection(crate::ConnectionError(error)),
607                            }))
608                        }
609                    };
610
611                    let local_addr = socketaddr_to_multiaddr(&self.socket_addr(), self.version);
612                    let remote_addr = connecting.remote_address();
613                    let send_back_addr = socketaddr_to_multiaddr(&remote_addr, self.version);
614
615                    let event = TransportEvent::Incoming {
616                        upgrade: Connecting::new(connecting, self.handshake_timeout),
617                        local_addr,
618                        send_back_addr,
619                        listener_id: self.listener_id,
620                    };
621                    return Poll::Ready(Some(event));
622                }
623                Poll::Ready(None) => {
624                    self.close(Ok(()));
625                    continue;
626                }
627                Poll::Pending => {}
628            };
629
630            self.close_listener_waker = Some(cx.waker().clone());
631
632            return Poll::Pending;
633        }
634    }
635}
636
637impl<P: Provider> fmt::Debug for Listener<P> {
638    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
639        f.debug_struct("Listener")
640            .field("listener_id", &self.listener_id)
641            .field("handshake_timeout", &self.handshake_timeout)
642            .field("is_closed", &self.is_closed)
643            .field("pending_event", &self.pending_event)
644            .finish()
645    }
646}
647
/// QUIC wire version used by a listener or a dial.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ProtocolVersion {
    /// QUIC version 1 as specified in RFC 9000; `/quic-v1` in a multiaddr.
    V1,
    /// IETF draft-29; `/quic` in a multiaddr, accepted only when draft-29 support is enabled.
    Draft29,
}
653
/// IP family (IPv4 or IPv6) of a socket address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum SocketFamily {
    Ipv4,
    Ipv6,
}

impl SocketFamily {
    /// Returns `true` when `a` and `b` are both IPv4 or both IPv6.
    fn is_same(a: &IpAddr, b: &IpAddr) -> bool {
        SocketFamily::from(*a) == SocketFamily::from(*b)
    }
}

impl From<IpAddr> for SocketFamily {
    /// Classifies `ip` by its address family.
    fn from(ip: IpAddr) -> Self {
        if ip.is_ipv4() {
            Self::Ipv4
        } else {
            Self::Ipv6
        }
    }
}
677
678/// Turn an [`IpAddr`] reported by the interface watcher into a
679/// listen-address for the endpoint.
680///
681/// For this, the `ip` is combined with the port that the endpoint
682/// is actually bound.
683///
684/// Returns `None` if the `ip` is not the same socket family as the
685/// address that the endpoint is bound to.
686fn ip_to_listenaddr(
687    endpoint_addr: &SocketAddr,
688    ip: IpAddr,
689    version: ProtocolVersion,
690) -> Option<Multiaddr> {
691    // True if either both addresses are Ipv4 or both Ipv6.
692    if !SocketFamily::is_same(&endpoint_addr.ip(), &ip) {
693        return None;
694    }
695    let socket_addr = SocketAddr::new(ip, endpoint_addr.port());
696    Some(socketaddr_to_multiaddr(&socket_addr, version))
697}
698
699/// Tries to turn a QUIC multiaddress into a UDP [`SocketAddr`]. Returns None if the format
700/// of the multiaddr is wrong.
701fn multiaddr_to_socketaddr(
702    addr: &Multiaddr,
703    support_draft_29: bool,
704) -> Option<(SocketAddr, ProtocolVersion, Option<PeerId>)> {
705    let mut iter = addr.iter();
706    let proto1 = iter.next()?;
707    let proto2 = iter.next()?;
708    let proto3 = iter.next()?;
709
710    let mut peer_id = None;
711    for proto in iter {
712        match proto {
713            Protocol::P2p(id) => {
714                peer_id = Some(id);
715            }
716            _ => return None,
717        }
718    }
719    let version = match proto3 {
720        Protocol::QuicV1 => ProtocolVersion::V1,
721        Protocol::Quic if support_draft_29 => ProtocolVersion::Draft29,
722        _ => return None,
723    };
724
725    match (proto1, proto2) {
726        (Protocol::Ip4(ip), Protocol::Udp(port)) => {
727            Some((SocketAddr::new(ip.into(), port), version, peer_id))
728        }
729        (Protocol::Ip6(ip), Protocol::Udp(port)) => {
730            Some((SocketAddr::new(ip.into(), port), version, peer_id))
731        }
732        _ => None,
733    }
734}
735
736/// Turns an IP address and port into the corresponding QUIC multiaddr.
737fn socketaddr_to_multiaddr(socket_addr: &SocketAddr, version: ProtocolVersion) -> Multiaddr {
738    let quic_proto = match version {
739        ProtocolVersion::V1 => Protocol::QuicV1,
740        ProtocolVersion::Draft29 => Protocol::Quic,
741    };
742    Multiaddr::empty()
743        .with(socket_addr.ip().into())
744        .with(Protocol::Udp(socket_addr.port()))
745        .with(quic_proto)
746}
747
748#[cfg(test)]
749#[cfg(any(feature = "async-std", feature = "tokio"))]
750mod tests {
751    use futures::future::poll_fn;
752
753    use super::*;
754
755    #[test]
756    fn multiaddr_to_udp_conversion() {
757        assert!(multiaddr_to_socketaddr(
758            &"/ip4/127.0.0.1/udp/1234".parse::<Multiaddr>().unwrap(),
759            true
760        )
761        .is_none());
762
763        assert_eq!(
764            multiaddr_to_socketaddr(
765                &"/ip4/127.0.0.1/udp/12345/quic-v1"
766                    .parse::<Multiaddr>()
767                    .unwrap(),
768                false
769            ),
770            Some((
771                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345,),
772                ProtocolVersion::V1,
773                None
774            ))
775        );
776        assert_eq!(
777            multiaddr_to_socketaddr(
778                &"/ip4/255.255.255.255/udp/8080/quic-v1"
779                    .parse::<Multiaddr>()
780                    .unwrap(),
781                false
782            ),
783            Some((
784                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)), 8080,),
785                ProtocolVersion::V1,
786                None
787            ))
788        );
789        assert_eq!(
790            multiaddr_to_socketaddr(
791                &"/ip4/127.0.0.1/udp/55148/quic-v1/p2p/12D3KooW9xk7Zp1gejwfwNpfm6L9zH5NL4Bx5rm94LRYJJHJuARZ"
792                    .parse::<Multiaddr>()
793                    .unwrap(), false
794            ),
795            Some((SocketAddr::new(
796                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
797                55148,
798            ), ProtocolVersion::V1, Some("12D3KooW9xk7Zp1gejwfwNpfm6L9zH5NL4Bx5rm94LRYJJHJuARZ".parse().unwrap())))
799        );
800        assert_eq!(
801            multiaddr_to_socketaddr(
802                &"/ip6/::1/udp/12345/quic-v1".parse::<Multiaddr>().unwrap(),
803                false
804            ),
805            Some((
806                SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 12345,),
807                ProtocolVersion::V1,
808                None
809            ))
810        );
811        assert_eq!(
812            multiaddr_to_socketaddr(
813                &"/ip6/ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/udp/8080/quic-v1"
814                    .parse::<Multiaddr>()
815                    .unwrap(),
816                false
817            ),
818            Some((
819                SocketAddr::new(
820                    IpAddr::V6(Ipv6Addr::new(
821                        65535, 65535, 65535, 65535, 65535, 65535, 65535, 65535,
822                    )),
823                    8080,
824                ),
825                ProtocolVersion::V1,
826                None
827            ))
828        );
829
830        assert!(multiaddr_to_socketaddr(
831            &"/ip4/127.0.0.1/udp/1234/quic".parse::<Multiaddr>().unwrap(),
832            false
833        )
834        .is_none());
835        assert_eq!(
836            multiaddr_to_socketaddr(
837                &"/ip4/127.0.0.1/udp/1234/quic".parse::<Multiaddr>().unwrap(),
838                true
839            ),
840            Some((
841                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 1234,),
842                ProtocolVersion::Draft29,
843                None
844            ))
845        );
846    }
847
848    #[cfg(feature = "async-std")]
849    #[async_std::test]
850    async fn test_close_listener() {
851        let keypair = libp2p_identity::Keypair::generate_ed25519();
852        let config = Config::new(&keypair);
853        let mut transport = crate::async_std::Transport::new(config);
854        assert!(poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx))
855            .now_or_never()
856            .is_none());
857
858        // Run test twice to check that there is no unexpected behaviour if `Transport.listener`
859        // is temporarily empty.
860        for _ in 0..2 {
861            let id = ListenerId::next();
862            transport
863                .listen_on(id, "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap())
864                .unwrap();
865
866            match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await {
867                TransportEvent::NewAddress {
868                    listener_id,
869                    listen_addr,
870                } => {
871                    assert_eq!(listener_id, id);
872                    assert!(
873                        matches!(listen_addr.iter().next(), Some(Protocol::Ip4(a)) if !a.is_unspecified())
874                    );
875                    assert!(
876                        matches!(listen_addr.iter().nth(1), Some(Protocol::Udp(port)) if port != 0)
877                    );
878                    assert!(matches!(listen_addr.iter().nth(2), Some(Protocol::QuicV1)));
879                }
880                e => panic!("Unexpected event: {e:?}"),
881            }
882            assert!(transport.remove_listener(id), "Expect listener to exist.");
883            match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await {
884                TransportEvent::ListenerClosed {
885                    listener_id,
886                    reason: Ok(()),
887                } => {
888                    assert_eq!(listener_id, id);
889                }
890                e => panic!("Unexpected event: {e:?}"),
891            }
892            // Poll once again so that the listener has the chance to return `Poll::Ready(None)` and
893            // be removed from the list of listeners.
894            assert!(poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx))
895                .now_or_never()
896                .is_none());
897            assert!(transport.listeners.is_empty());
898        }
899    }
900
901    #[cfg(feature = "tokio")]
902    #[tokio::test]
903    async fn test_dialer_drop() {
904        let keypair = libp2p_identity::Keypair::generate_ed25519();
905        let config = Config::new(&keypair);
906        let mut transport = crate::tokio::Transport::new(config);
907
908        let _dial = transport
909            .dial(
910                "/ip4/123.45.67.8/udp/1234/quic-v1".parse().unwrap(),
911                DialOpts {
912                    role: Endpoint::Dialer,
913                    port_use: PortUse::Reuse,
914                },
915            )
916            .unwrap();
917
918        assert!(transport.dialer.contains_key(&SocketFamily::Ipv4));
919        assert!(!transport.dialer.contains_key(&SocketFamily::Ipv6));
920
921        // Start listening so that the dialer and driver are dropped.
922        transport
923            .listen_on(
924                ListenerId::next(),
925                "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap(),
926            )
927            .unwrap();
928        assert!(!transport.dialer.contains_key(&SocketFamily::Ipv4));
929    }
930
931    #[cfg(feature = "tokio")]
932    #[tokio::test]
933    async fn test_listens_ipv4_ipv6_separately() {
934        let keypair = libp2p_identity::Keypair::generate_ed25519();
935        let config = Config::new(&keypair);
936        let mut transport = crate::tokio::Transport::new(config);
937        let port = {
938            let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
939            socket.local_addr().unwrap().port()
940        };
941
942        transport
943            .listen_on(
944                ListenerId::next(),
945                format!("/ip4/0.0.0.0/udp/{port}/quic-v1").parse().unwrap(),
946            )
947            .unwrap();
948        transport
949            .listen_on(
950                ListenerId::next(),
951                format!("/ip6/::/udp/{port}/quic-v1").parse().unwrap(),
952            )
953            .unwrap();
954    }
955}