1mod error;
22
23pub(crate) mod pool;
24mod supported_protocols;
25
26use std::{
27 collections::{HashMap, HashSet},
28 fmt,
29 fmt::{Display, Formatter},
30 future::Future,
31 io, mem,
32 pin::Pin,
33 sync::atomic::{AtomicUsize, Ordering},
34 task::{Context, Poll, Waker},
35 time::Duration,
36};
37
38pub use error::ConnectionError;
39pub(crate) use error::{
40 PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError,
41};
42use futures::{future::BoxFuture, stream, stream::FuturesUnordered, FutureExt, StreamExt};
43use futures_timer::Delay;
44use libp2p_core::{
45 connection::ConnectedPoint,
46 multiaddr::Multiaddr,
47 muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox},
48 transport::PortUse,
49 upgrade,
50 upgrade::{NegotiationError, ProtocolError},
51 Endpoint,
52};
53use libp2p_identity::PeerId;
54pub use supported_protocols::SupportedProtocols;
55use web_time::Instant;
56
57use crate::{
58 handler::{
59 AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError,
60 FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport,
61 ProtocolsChange, UpgradeInfoSend,
62 },
63 stream::ActiveStreamCounter,
64 upgrade::{InboundUpgradeSend, OutboundUpgradeSend},
65 ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol,
66};
67
/// Source of fresh [`ConnectionId`]s; the first id handed out is `1`.
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Identifier of a connection.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Wraps a raw `usize` as a [`ConnectionId`].
    ///
    /// No uniqueness check is performed: the returned id may be equal to one
    /// that `next` has handed out (or will hand out).
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Allocates the next id from the global counter.
    pub(crate) fn next() -> Self {
        let raw = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        ConnectionId(raw)
    }
}

impl Display for ConnectionId {
    /// Writes the bare decimal value of the id.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let ConnectionId(raw) = self;
        write!(f, "{raw}")
    }
}
97
/// Information about an established connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Connected {
    /// The connected endpoint (dialer or listener side) with its address information.
    pub(crate) endpoint: ConnectedPoint,
    /// The [`PeerId`] of the peer at the other end of the connection
    /// (set by whoever constructs this value, not in this module).
    pub(crate) peer_id: PeerId,
}

/// Event produced by [`Connection::poll`].
#[derive(Debug, Clone)]
pub(crate) enum Event<T> {
    /// An event the [`ConnectionHandler`] emitted via `NotifyBehaviour`.
    Handler(T),
    /// The muxer reported a new address for the connection.
    AddressChange(Multiaddr),
}
115
/// A multiplexed connection driven by a [`ConnectionHandler`].
///
/// The connection owns the stream muxer, the handler and all substreams that
/// are currently being requested or negotiated.
pub(crate) struct Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// The stream muxer that yields inbound/outbound substreams and muxer events.
    muxing: StreamMuxerBox,
    /// The handler that drives the protocols on this connection.
    handler: THandler,
    /// Inbound substreams currently going through protocol negotiation and the
    /// handler's inbound upgrade.
    #[expect(deprecated)]
    negotiating_in: FuturesUnordered<
        StreamUpgrade<
            THandler::InboundOpenInfo,
            <THandler::InboundProtocol as InboundUpgradeSend>::Output,
            <THandler::InboundProtocol as InboundUpgradeSend>::Error,
        >,
    >,
    /// Outbound substreams currently going through protocol negotiation and the
    /// handler's outbound upgrade.
    #[expect(deprecated)]
    negotiating_out: FuturesUnordered<
        StreamUpgrade<
            THandler::OutboundOpenInfo,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
        >,
    >,
    /// Idle-shutdown state. It is only advanced while nothing is being requested
    /// or negotiated and no streams are active; otherwise it is reset to `None`.
    shutdown: Shutdown,
    /// Multistream-select version to use for outbound negotiations instead of
    /// the default, if any.
    substream_upgrade_protocol_override: Option<upgrade::Version>,
    /// Upper bound on the number of inbound substreams in negotiation. While it
    /// is reached, no further inbound substreams are accepted from the muxer.
    max_negotiating_inbound_streams: usize,
    /// Outbound substreams the handler asked for that the muxer has not opened
    /// yet. Each entry resolves to `Err(info)` if its upgrade timeout fires
    /// before the muxer provides a substream.
    #[expect(deprecated)]
    requested_substreams: FuturesUnordered<
        SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
    >,

    /// Protocols the handler's listen protocol advertised at the last check,
    /// used to detect and report local protocol changes.
    /// NOTE(review): the meaning of the `bool` value is defined by
    /// `ProtocolsChange::from_full_sets` (not visible here) — confirm there.
    local_supported_protocols:
        HashMap<AsStrHashEq<<THandler::InboundProtocol as UpgradeInfoSend>::Info>, bool>,
    /// Protocols the handler reported as supported by the remote.
    remote_supported_protocols: HashSet<StreamProtocol>,
    /// Scratch buffer reused when computing `ProtocolsChange`s.
    protocol_buffer: Vec<StreamProtocol>,

    /// How long the connection may be idle before it is shut down.
    idle_timeout: Duration,
    /// Shared with every `Stream` created during negotiation. While any are
    /// active, the connection is not considered idle.
    stream_counter: ActiveStreamCounter,
}
173
174#[expect(deprecated)] impl<THandler> fmt::Debug for Connection<THandler>
176where
177 THandler: ConnectionHandler + fmt::Debug,
178 THandler::OutboundOpenInfo: fmt::Debug,
179{
180 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 f.debug_struct("Connection")
182 .field("handler", &self.handler)
183 .finish()
184 }
185}
186
187impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
188
impl<THandler> Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// Builds a new `Connection` from a stream muxer and a connection handler.
    ///
    /// If the handler's listen protocol advertises any protocols, the handler
    /// receives one `ConnectionEvent::LocalProtocolsChange` listing them before
    /// the connection is returned.
    pub(crate) fn new(
        muxer: StreamMuxerBox,
        mut handler: THandler,
        substream_upgrade_protocol_override: Option<upgrade::Version>,
        max_negotiating_inbound_streams: usize,
        idle_timeout: Duration,
    ) -> Self {
        let initial_protocols = gather_supported_protocols(&handler);
        let mut buffer = Vec::new();

        // Only report the initial protocol set when there is something to report.
        if !initial_protocols.is_empty() {
            handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
                ProtocolsChange::from_initial_protocols(
                    initial_protocols.keys().map(|e| &e.0),
                    &mut buffer,
                ),
            ));
        }

        Connection {
            muxing: muxer,
            handler,
            negotiating_in: Default::default(),
            negotiating_out: Default::default(),
            shutdown: Shutdown::None,
            substream_upgrade_protocol_override,
            max_negotiating_inbound_streams,
            requested_substreams: Default::default(),
            local_supported_protocols: initial_protocols,
            remote_supported_protocols: Default::default(),
            protocol_buffer: buffer,
            idle_timeout,
            stream_counter: ActiveStreamCounter::default(),
        }
    }

    /// Forwards an event from the behaviour to the connection handler.
    pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
        self.handler.on_behaviour_event(event);
    }

    /// Consumes the connection and starts closing it.
    ///
    /// Returns:
    /// - a stream of the remaining events produced by the handler's `poll_close`, and
    /// - a future that closes the muxer.
    ///
    /// All other state (requested and negotiating substreams, etc.) is dropped.
    pub(crate) fn close(
        self,
    ) -> (
        impl futures::Stream<Item = THandler::ToBehaviour>,
        impl Future<Output = io::Result<()>>,
    ) {
        let Connection {
            mut handler,
            muxing,
            ..
        } = self;

        (
            stream::poll_fn(move |cx| handler.poll_close(cx)),
            muxing.close(),
        )
    }

    /// Drives the connection forward.
    ///
    /// Each loop iteration polls, in order:
    /// 1. the requested outbound substreams (for timeouts),
    /// 2. the handler,
    /// 3. the negotiating outbound substreams, then the negotiating inbound ones,
    /// 4. the idle-shutdown timer,
    /// 5. the muxer (events, then new outbound and inbound substreams),
    /// 6. the handler's listen protocols, for changes.
    ///
    /// Whenever one of these makes progress the loop starts over. Returns:
    /// - `Ok(Event::Handler(_))` when the handler emits `NotifyBehaviour`,
    /// - `Ok(Event::AddressChange(_))` when the muxer reports a new address,
    /// - `Err(ConnectionError::KeepAliveTimeout)` when the idle shutdown triggers,
    /// - `Pending` once nothing can make progress.
    ///
    /// Muxer errors are propagated via `?`.
    #[tracing::instrument(level = "debug", name = "Connection::poll", skip(self, cx))]
    pub(crate) fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        let Self {
            requested_substreams,
            muxing,
            handler,
            negotiating_out,
            negotiating_in,
            shutdown,
            max_negotiating_inbound_streams,
            substream_upgrade_protocol_override,
            local_supported_protocols: supported_protocols,
            remote_supported_protocols,
            protocol_buffer,
            idle_timeout,
            stream_counter,
            ..
        } = self.get_mut();

        loop {
            // Report requested outbound substreams whose timeout fired before the
            // muxer opened them. `Ok(())` means the request was already extracted.
            match requested_substreams.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(()))) => continue,
                Poll::Ready(Some(Err(info))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError {
                            info,
                            error: StreamUpgradeError::Timeout,
                        },
                    ));
                    continue;
                }
                Poll::Ready(None) | Poll::Pending => {}
            }

            // Poll the handler.
            match handler.poll(cx) {
                Poll::Pending => {}
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                    let timeout = *protocol.timeout();
                    let (upgrade, user_data) = protocol.into_upgrade();

                    // The upgrade timeout starts now, while waiting for the muxer.
                    requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
                    continue; // Poll the handler until it is exhausted.
                }
                Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
                    return Poll::Ready(Ok(Event::Handler(event)));
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(protocols),
                )) => {
                    // Only the protocols not already known are reported; they are
                    // recorded once the handler has seen the change.
                    if let Some(added) =
                        ProtocolsChange::add(remote_supported_protocols, protocols, protocol_buffer)
                    {
                        handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
                        remote_supported_protocols.extend(protocol_buffer.drain(..));
                    }
                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(protocols),
                )) => {
                    // NOTE(review): `remove` receives the set mutably and presumably
                    // removes the entries itself; nothing is removed here — confirm.
                    if let Some(removed) = ProtocolsChange::remove(
                        remote_supported_protocols,
                        protocols,
                        protocol_buffer,
                    ) {
                        handler
                            .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
                    }
                    continue;
                }
            }

            // The handler cannot make progress; poll the negotiating outbound substreams.
            match negotiating_out.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
                        FullyNegotiatedOutbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(error)))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError { info, error },
                    ));
                    continue;
                }
            }

            // Then the negotiating inbound substreams. Only upgrade (`Apply`) errors
            // reach the handler; I/O, negotiation and timeout failures are only logged.
            match negotiating_in.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
                        FullyNegotiatedInbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
                    handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
                        ListenUpgradeError { info, error },
                    ));
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
                    tracing::debug!("failed to upgrade inbound stream: {e}");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
                    tracing::debug!("no protocol could be agreed upon for inbound stream");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
                    tracing::debug!("inbound stream upgrade timed out");
                    continue;
                }
            }

            // Idle shutdown is only considered when nothing is requested or being
            // negotiated and no stream is active; otherwise it is cancelled.
            if negotiating_in.is_empty()
                && negotiating_out.is_empty()
                && requested_substreams.is_empty()
                && stream_counter.has_no_active_streams()
            {
                if let Some(new_timeout) =
                    compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
                {
                    *shutdown = new_timeout;
                }

                match shutdown {
                    Shutdown::None => {}
                    Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
                    Shutdown::Later(delay) => match Future::poll(Pin::new(delay), cx) {
                        Poll::Ready(_) => {
                            return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
                        }
                        Poll::Pending => {}
                    },
                }
            } else {
                *shutdown = Shutdown::None;
            }

            // Muxer-level events; the handler is told before the event is returned.
            match muxing.poll_unpin(cx)? {
                Poll::Pending => {}
                Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
                    handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
                        new_address: &address,
                    }));
                    return Poll::Ready(Ok(Event::AddressChange(address)));
                }
            }

            // Only ask the muxer for an outbound substream if one has been requested.
            if let Some(requested_substream) = requested_substreams.iter_mut().next() {
                match muxing.poll_outbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        // Extraction wakes the request so `requested_substreams`
                        // yields `Ok(())` for it and drops it.
                        let (user_data, timeout, upgrade) = requested_substream.extract();

                        negotiating_out.push(StreamUpgrade::new_outbound(
                            substream,
                            user_data,
                            timeout,
                            upgrade,
                            *substream_upgrade_protocol_override,
                            stream_counter.clone(),
                        ));

                        // Go back to the top; the handler may be able to make progress again.
                        continue;
                    }
                }
            }

            // Accept new inbound substreams only while below the negotiation limit.
            if negotiating_in.len() < *max_negotiating_inbound_streams {
                match muxing.poll_inbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let protocol = handler.listen_protocol();

                        negotiating_in.push(StreamUpgrade::new_inbound(
                            substream,
                            protocol,
                            stream_counter.clone(),
                        ));

                        // Go back to the top; the handler may be able to make progress again.
                        continue;
                    }
                }
            }

            // Detect changes to the protocols the handler listens on.
            let changes = ProtocolsChange::from_full_sets(
                supported_protocols,
                handler.listen_protocol().upgrade().protocol_info(),
                protocol_buffer,
            );

            if !changes.is_empty() {
                for change in changes {
                    handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
                }
                // Go back to the top; the handler may be able to make progress again.
                continue;
            }

            // Nothing can make progress.
            return Poll::Pending;
        }
    }

    /// Test helper: polls the connection once with a no-op waker.
    #[cfg(test)]
    fn poll_noop_waker(&mut self) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
    }
}
485
486fn gather_supported_protocols<C: ConnectionHandler>(
487 handler: &C,
488) -> HashMap<AsStrHashEq<<C::InboundProtocol as UpgradeInfoSend>::Info>, bool> {
489 handler
490 .listen_protocol()
491 .upgrade()
492 .protocol_info()
493 .map(|info| (AsStrHashEq(info), true))
494 .collect()
495}
496
497fn compute_new_shutdown(
498 handler_keep_alive: bool,
499 current_shutdown: &Shutdown,
500 idle_timeout: Duration,
501) -> Option<Shutdown> {
502 match (current_shutdown, handler_keep_alive) {
503 (_, false) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
504 (Shutdown::Later(_), false) => None,
506 (_, false) => {
507 let now = Instant::now();
508 let safe_keep_alive = checked_add_fraction(now, idle_timeout);
509
510 Some(Shutdown::Later(Delay::new(safe_keep_alive)))
511 }
512 (_, true) => Some(Shutdown::None),
513 }
514}
515
516fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
523 while start.checked_add(duration).is_none() {
524 tracing::debug!(start=?start, duration=?duration, "start + duration cannot be presented, halving duration");
525
526 duration /= 2;
527 }
528
529 duration
530}
531
532#[derive(Debug, Copy, Clone)]
534pub(crate) struct IncomingInfo<'a> {
535 pub(crate) local_addr: &'a Multiaddr,
537 pub(crate) send_back_addr: &'a Multiaddr,
539}
540
541impl IncomingInfo<'_> {
542 pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
544 ConnectedPoint::Listener {
545 local_addr: self.local_addr.clone(),
546 send_back_addr: self.send_back_addr.clone(),
547 }
548 }
549}
550
/// A substream going through protocol negotiation and upgrade, bounded by a timeout.
///
/// Resolves to the user data together with either the upgrade output or an error.
struct StreamUpgrade<UserData, TOk, TErr> {
    /// Handed back with the result; taken (left `None`) once the future completes.
    user_data: Option<UserData>,
    /// When this fires first, the future resolves with `StreamUpgradeError::Timeout`.
    timeout: Delay,
    /// Protocol negotiation followed by the upgrade itself.
    upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
}
556
557impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
558 fn new_outbound<Upgrade>(
559 substream: SubstreamBox,
560 user_data: UserData,
561 timeout: Delay,
562 upgrade: Upgrade,
563 version_override: Option<upgrade::Version>,
564 counter: ActiveStreamCounter,
565 ) -> Self
566 where
567 Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
568 {
569 let effective_version = match version_override {
570 Some(version_override) if version_override != upgrade::Version::default() => {
571 tracing::debug!(
572 "Substream upgrade protocol override: {:?} -> {:?}",
573 upgrade::Version::default(),
574 version_override
575 );
576
577 version_override
578 }
579 _ => upgrade::Version::default(),
580 };
581 let protocols = upgrade.protocol_info();
582
583 Self {
584 user_data: Some(user_data),
585 timeout,
586 upgrade: Box::pin(async move {
587 let (info, stream) = multistream_select::dialer_select_proto(
588 substream,
589 protocols,
590 effective_version,
591 )
592 .await
593 .map_err(to_stream_upgrade_error)?;
594
595 let output = upgrade
596 .upgrade_outbound(Stream::new(stream, counter), info)
597 .await
598 .map_err(StreamUpgradeError::Apply)?;
599
600 Ok(output)
601 }),
602 }
603 }
604}
605
606impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
607 fn new_inbound<Upgrade>(
608 substream: SubstreamBox,
609 protocol: SubstreamProtocol<Upgrade, UserData>,
610 counter: ActiveStreamCounter,
611 ) -> Self
612 where
613 Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
614 {
615 let timeout = *protocol.timeout();
616 let (upgrade, open_info) = protocol.into_upgrade();
617 let protocols = upgrade.protocol_info();
618
619 Self {
620 user_data: Some(open_info),
621 timeout: Delay::new(timeout),
622 upgrade: Box::pin(async move {
623 let (info, stream) =
624 multistream_select::listener_select_proto(substream, protocols)
625 .await
626 .map_err(to_stream_upgrade_error)?;
627
628 let output = upgrade
629 .upgrade_inbound(Stream::new(stream, counter), info)
630 .await
631 .map_err(StreamUpgradeError::Apply)?;
632
633 Ok(output)
634 }),
635 }
636 }
637}
638
639fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
640 match e {
641 NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
642 NegotiationError::ProtocolError(ProtocolError::IoError(e)) => StreamUpgradeError::Io(e),
643 NegotiationError::ProtocolError(other) => {
644 StreamUpgradeError::Io(io::Error::new(io::ErrorKind::Other, other))
645 }
646 }
647}
648
649impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
650
651impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
652 type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);
653
654 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
655 match self.timeout.poll_unpin(cx) {
656 Poll::Ready(()) => {
657 return Poll::Ready((
658 self.user_data
659 .take()
660 .expect("Future not to be polled again once ready."),
661 Err(StreamUpgradeError::Timeout),
662 ))
663 }
664
665 Poll::Pending => {}
666 }
667
668 let result = futures::ready!(self.upgrade.poll_unpin(cx));
669 let user_data = self
670 .user_data
671 .take()
672 .expect("Future not to be polled again once ready.");
673
674 Poll::Ready((user_data, result))
675 }
676}
677
/// An outbound substream requested by the handler that the muxer has not opened yet.
///
/// Polled as a future, it resolves to:
/// - `Err(user_data)` if `timeout` fires while still `Waiting`,
/// - `Ok(())` once the request has been taken out with `extract`.
enum SubstreamRequested<UserData, Upgrade> {
    /// The request is still pending.
    Waiting {
        user_data: UserData,
        timeout: Delay,
        upgrade: Upgrade,
        /// Waker from the most recent poll.
        ///
        /// `extract` wakes it so the future is polled again, observes `Done`
        /// and resolves with `Ok(())`.
        extracted_waker: Option<Waker>,
    },
    /// The request was extracted, or its timeout already fired.
    Done,
}
691
692impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
693 fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
694 Self::Waiting {
695 user_data,
696 timeout: Delay::new(timeout),
697 upgrade,
698 extracted_waker: None,
699 }
700 }
701
702 fn extract(&mut self) -> (UserData, Delay, Upgrade) {
703 match mem::replace(self, Self::Done) {
704 SubstreamRequested::Waiting {
705 user_data,
706 timeout,
707 upgrade,
708 extracted_waker: waker,
709 } => {
710 if let Some(waker) = waker {
711 waker.wake();
712 }
713
714 (user_data, timeout, upgrade)
715 }
716 SubstreamRequested::Done => panic!("cannot extract twice"),
717 }
718 }
719}
720
721impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
722
723impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
724 type Output = Result<(), UserData>;
725
726 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
727 let this = self.get_mut();
728
729 match mem::replace(this, Self::Done) {
730 SubstreamRequested::Waiting {
731 user_data,
732 upgrade,
733 mut timeout,
734 ..
735 } => match timeout.poll_unpin(cx) {
736 Poll::Ready(()) => Poll::Ready(Err(user_data)),
737 Poll::Pending => {
738 *this = Self::Waiting {
739 user_data,
740 upgrade,
741 timeout,
742 extracted_waker: Some(cx.waker().clone()),
743 };
744 Poll::Pending
745 }
746 },
747 SubstreamRequested::Done => Poll::Ready(Ok(())),
748 }
749 }
750}
751
/// State of the idle-connection shutdown.
#[derive(Debug)]
enum Shutdown {
    /// No shutdown is scheduled.
    None,
    /// Shut down as soon as possible (chosen when the idle timeout is zero).
    Asap,
    /// Shut down once the delay fires.
    Later(Delay),
}
770
/// Wrapper that compares and hashes a value through its `AsRef<str>` view.
///
/// Equality and hashing both use the same string slice, so they stay
/// consistent with each other, as hash-map keys require.
pub(crate) struct AsStrHashEq<T>(pub(crate) T);

impl<T: AsRef<str>> PartialEq for AsStrHashEq<T> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs): (&str, &str) = (self.0.as_ref(), other.0.as_ref());
        lhs == rhs
    }
}

impl<T: AsRef<str>> Eq for AsStrHashEq<T> {}

impl<T: AsRef<str>> std::hash::Hash for AsStrHashEq<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let text: &str = self.0.as_ref();
        std::hash::Hash::hash(text, state)
    }
}
789
#[cfg(test)]
mod tests {
    use std::{
        convert::Infallible,
        sync::{Arc, Weak},
        time::Instant,
    };

    use futures::{future, AsyncRead, AsyncWrite};
    use libp2p_core::{
        upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo},
        StreamMuxer,
    };
    use quickcheck::*;
    use tracing_subscriber::EnvFilter;

    use super::*;
    use crate::dummy;

    /// Even with a muxer that yields inbound substreams endlessly, no more than
    /// `max_negotiating_inbound_streams` may be in negotiation at once.
    #[test]
    fn max_negotiating_inbound_streams() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(max_negotiating_inbound_streams: u8) {
            let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();

            let alive_substream_counter = Arc::new(());
            let mut connection = Connection::new(
                StreamMuxerBox::new(DummyStreamMuxer {
                    counter: alive_substream_counter.clone(),
                }),
                MockConnectionHandler::new(Duration::from_secs(10)),
                None,
                max_negotiating_inbound_streams,
                Duration::ZERO,
            );

            let result = connection.poll_noop_waker();

            assert!(result.is_pending());
            assert_eq!(
                Arc::weak_count(&alive_substream_counter),
                max_negotiating_inbound_streams,
                "Expect no more than the maximum number of allowed streams"
            );
        }

        QuickCheck::new().quickcheck(prop as fn(_));
    }

    /// The upgrade timeout of an outbound request runs while waiting for the
    /// muxer, i.e. it starts when the handler makes the request.
    #[test]
    fn outbound_stream_timeout_starts_on_request() {
        let upgrade_timeout = Duration::from_secs(1);
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            MockConnectionHandler::new(upgrade_timeout),
            None,
            2,
            Duration::ZERO,
        );

        connection.handler.open_new_outbound();
        let _ = connection.poll_noop_waker();

        std::thread::sleep(upgrade_timeout + Duration::from_secs(1));

        let _ = connection.poll_noop_waker();

        assert!(matches!(
            connection.handler.error.unwrap(),
            StreamUpgradeError::Timeout
        ))
    }

    #[test]
    fn propagates_changes_to_supported_inbound_protocols() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // Start listening on a single protocol.
        connection.handler.listen_on(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
        assert!(connection.handler.local_removed.is_empty());

        // Add a second protocol; only the new one must be reported.
        connection.handler.listen_on(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocols"
        );
        assert!(connection.handler.local_removed.is_empty());

        // Drop the first protocol again.
        connection.handler.listen_on(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
    }

    #[test]
    fn only_propagates_actual_changes_to_remote_protocols_to_handler() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // The remote supports a single protocol.
        connection.handler.remote_adds_support_for(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
        assert!(connection.handler.remote_removed.is_empty());

        // Re-reporting a known protocol together with a new one only reports the new one.
        connection
            .handler
            .remote_adds_support_for(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocol"
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Removing a protocol that was never supported is not a change.
        connection.handler.remote_removes_support_for(&["/baz"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Removing a supported protocol is reported.
        connection.handler.remote_removes_support_for(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
    }

    #[tokio::test]
    async fn idle_timeout_with_keep_alive_no() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            dummy::ConnectionHandler,
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    #[test]
    fn checked_add_fraction_can_add_u64_max() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();
        let start = Instant::now();

        let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));

        assert!(start.checked_add(duration).is_some())
    }

    #[test]
    fn compute_new_shutdown_does_not_panic() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        #[derive(Debug)]
        struct ArbitraryShutdown(Shutdown);

        impl Clone for ArbitraryShutdown {
            fn clone(&self) -> Self {
                let shutdown = match self.0 {
                    Shutdown::None => Shutdown::None,
                    Shutdown::Asap => Shutdown::Asap,
                    Shutdown::Later(_) => Shutdown::Later(
                        // `compute_new_shutdown` never inspects the delay and `Delay`
                        // does not implement `Clone`, so a placeholder is used.
                        Delay::new(Duration::from_secs(1)),
                    ),
                };

                ArbitraryShutdown(shutdown)
            }
        }

        impl Arbitrary for ArbitraryShutdown {
            fn arbitrary(g: &mut Gen) -> Self {
                let shutdown = match g.gen_range(1u8..4) {
                    1 => Shutdown::None,
                    2 => Shutdown::Asap,
                    3 => Shutdown::Later(Delay::new(Duration::from_secs(u64::from(
                        u32::arbitrary(g),
                    )))),
                    _ => unreachable!(),
                };

                Self(shutdown)
            }
        }

        fn prop(
            handler_keep_alive: bool,
            current_shutdown: ArbitraryShutdown,
            idle_timeout: Duration,
        ) {
            compute_new_shutdown(handler_keep_alive, &current_shutdown.0, idle_timeout);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _));
    }

    /// Muxer that yields a new inbound substream on every poll and never opens
    /// outbound ones. Each substream holds a `Weak` to `counter` so tests can
    /// count how many are alive.
    struct DummyStreamMuxer {
        counter: Arc<()>,
    }

    impl StreamMuxer for DummyStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Ready(Ok(PendingSubstream {
                _weak: Arc::downgrade(&self.counter),
            }))
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Muxer for which every operation stays pending forever.
    struct PendingStreamMuxer;

    impl StreamMuxer for PendingStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Infallible;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Substream whose reads and writes never complete; it keeps a `Weak`
    /// alive for as long as it exists.
    struct PendingSubstream {
        _weak: Weak<()>,
    }

    impl AsyncRead for PendingSubstream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for PendingSubstream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Handler that can request one outbound substream and records the last
    /// dial upgrade error it receives.
    struct MockConnectionHandler {
        outbound_requested: bool,
        error: Option<StreamUpgradeError<Infallible>>,
        upgrade_timeout: Duration,
    }

    impl MockConnectionHandler {
        fn new(upgrade_timeout: Duration) -> Self {
            Self {
                outbound_requested: false,
                error: None,
                upgrade_timeout,
            }
        }

        /// Makes the next `poll` emit an `OutboundSubstreamRequest`.
        fn open_new_outbound(&mut self) {
            self.outbound_requested = true;
        }
    }

    /// Handler whose listen protocols are configurable and which records every
    /// local/remote protocol change it is notified of.
    #[derive(Default)]
    struct ConfigurableProtocolConnectionHandler {
        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Infallible>>,
        active_protocols: HashSet<StreamProtocol>,
        local_added: Vec<Vec<StreamProtocol>>,
        local_removed: Vec<Vec<StreamProtocol>>,
        remote_added: Vec<Vec<StreamProtocol>>,
        remote_removed: Vec<Vec<StreamProtocol>>,
    }

    impl ConfigurableProtocolConnectionHandler {
        /// Replaces the set of protocols the handler listens on.
        fn listen_on(&mut self, protocols: &[&'static str]) {
            self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
        }

        /// Queues a report that the remote supports `protocols`.
        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }

        /// Queues a report that the remote no longer supports `protocols`.
        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }
    }

    impl ConnectionHandler for MockConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
        ) {
            match event {
                // The payload is uninhabited (it is passed to `unreachable`), so
                // depending on the compiler version these arms may be flagged as
                // unreachable.
                #[allow(unreachable_patterns)]
                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                    protocol,
                    ..
                }) => libp2p_core::util::unreachable(protocol),
                #[allow(unreachable_patterns)]
                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                    protocol,
                    ..
                }) => libp2p_core::util::unreachable(protocol),
                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                    self.error = Some(error)
                }
                #[allow(unreachable_patterns)]
                ConnectionEvent::AddressChange(_)
                | ConnectionEvent::ListenUpgradeError(_)
                | ConnectionEvent::LocalProtocolsChange(_)
                | ConnectionEvent::RemoteProtocolsChange(_) => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            // `FromBehaviour` is uninhabited; see the note in `on_connection_event`.
            #[allow(unreachable_patterns)]
            libp2p_core::util::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
            if self.outbound_requested {
                self.outbound_requested = false;
                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: SubstreamProtocol::new(DeniedUpgrade, ())
                        .with_timeout(self.upgrade_timeout),
                });
            }

            Poll::Pending
        }
    }

    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
        type FromBehaviour = Infallible;
        type ToBehaviour = Infallible;
        type InboundProtocol = ManyProtocolsUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol> {
            SubstreamProtocol::new(
                ManyProtocolsUpgrade {
                    protocols: Vec::from_iter(self.active_protocols.clone()),
                },
                (),
            )
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<Self::InboundProtocol, Self::OutboundProtocol>,
        ) {
            match event {
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.local_added.push(added.cloned().collect())
                }
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.local_removed.push(removed.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.remote_added.push(added.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.remote_removed.push(removed.cloned().collect())
                }
                _ => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            // `FromBehaviour` is uninhabited; the lint may fire depending on the compiler.
            #[allow(unreachable_patterns)]
            libp2p_core::util::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<ConnectionHandlerEvent<Self::OutboundProtocol, (), Self::ToBehaviour>> {
            if let Some(event) = self.events.pop() {
                return Poll::Ready(event);
            }

            Poll::Pending
        }
    }

    /// Upgrade that advertises an arbitrary list of protocols and passes the
    /// stream through unchanged.
    struct ManyProtocolsUpgrade {
        protocols: Vec<StreamProtocol>,
    }

    impl UpgradeInfo for ManyProtocolsUpgrade {
        type Info = StreamProtocol;
        type InfoIter = std::vec::IntoIter<Self::Info>;

        fn protocol_info(&self) -> Self::InfoIter {
            self.protocols.clone().into_iter()
        }
    }

    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }

    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Infallible;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }
}
1355
/// Endpoint information for a connection that is still being established.
///
/// It mirrors [`ConnectedPoint`] (see the `From` conversion below). The
/// difference is that the dialer variant only keeps `role_override` and
/// `port_use`; the remaining `ConnectedPoint::Dialer` fields are dropped.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum PendingPoint {
    /// The local node is dialing.
    Dialer {
        /// Copied from `ConnectedPoint::Dialer::role_override`.
        role_override: Endpoint,
        /// Copied from `ConnectedPoint::Dialer::port_use`.
        port_use: PortUse,
    },
    /// The local node is listening and received the connection.
    Listener {
        /// Local address the connection was received on.
        local_addr: Multiaddr,
        /// Address used to send data back to the remote.
        send_back_addr: Multiaddr,
    },
}
1377
1378impl From<ConnectedPoint> for PendingPoint {
1379 fn from(endpoint: ConnectedPoint) -> Self {
1380 match endpoint {
1381 ConnectedPoint::Dialer {
1382 role_override,
1383 port_use,
1384 ..
1385 } => PendingPoint::Dialer {
1386 role_override,
1387 port_use,
1388 },
1389 ConnectedPoint::Listener {
1390 local_addr,
1391 send_back_addr,
1392 } => PendingPoint::Listener {
1393 local_addr,
1394 send_back_addr,
1395 },
1396 }
1397 }
1398}