1pub mod either;
42mod map_in;
43mod map_out;
44pub mod multi;
45mod one_shot;
46mod pending;
47mod select;
48
49use core::slice;
50use std::{
51 collections::{HashMap, HashSet},
52 error, fmt, io,
53 task::{Context, Poll},
54 time::Duration,
55};
56
57use libp2p_core::Multiaddr;
58pub use map_in::MapInEvent;
59pub use map_out::MapOutEvent;
60pub use one_shot::{OneShotHandler, OneShotHandlerConfig};
61pub use pending::PendingConnectionHandler;
62pub use select::ConnectionHandlerSelect;
63use smallvec::SmallVec;
64
65pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper, UpgradeInfoSend};
66use crate::{connection::AsStrHashEq, StreamProtocol};
67
/// Handles a single connection.
///
/// The handler receives events from the behaviour through [`Self::on_behaviour_event`] and
/// connection-level events through [`Self::on_connection_event`]. It reports back by
/// returning [`ConnectionHandlerEvent`]s from [`Self::poll`].
// NOTE(review): the `expect(deprecated)` presumably exists because the method signatures
// below reference the deprecated `InboundOpenInfo` / `OutboundOpenInfo` associated types.
#[expect(deprecated)]
pub trait ConnectionHandler: Send + 'static {
    /// Event type received from the behaviour via [`Self::on_behaviour_event`].
    type FromBehaviour: fmt::Debug + Send + 'static;
    /// Event type reported to the behaviour. It is emitted as
    /// [`ConnectionHandlerEvent::NotifyBehaviour`] from [`Self::poll`] and returned from
    /// [`Self::poll_close`].
    type ToBehaviour: fmt::Debug + Send + 'static;
    /// Upgrade applied to inbound substreams; returned by [`Self::listen_protocol`].
    type InboundProtocol: InboundUpgradeSend;
    /// Upgrade applied to outbound substreams requested from [`Self::poll`].
    type OutboundProtocol: OutboundUpgradeSend;
    /// Data attached to inbound substreams and handed back in [`ConnectionEvent`]s.
    #[deprecated = "Track data in ConnectionHandler instead."]
    type InboundOpenInfo: Send + 'static;
    /// Data attached to outbound substream requests and handed back in [`ConnectionEvent`]s.
    #[deprecated = "Track data in ConnectionHandler instead."]
    type OutboundOpenInfo: Send + 'static;

    /// Returns the upgrade, info and timeout to use for inbound substreams.
    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;

    /// Whether this handler asks for the connection to be kept alive.
    ///
    /// The default implementation returns `false`.
    fn connection_keep_alive(&self) -> bool {
        false
    }

    /// Polls the handler for its next [`ConnectionHandlerEvent`].
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<
        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
    >;

    /// Gives the handler a chance to emit remaining events while the connection shuts down.
    /// It is presumably polled until it returns `Poll::Ready(None)` (TODO confirm against
    /// the caller).
    ///
    /// The default implementation returns `Poll::Ready(None)` immediately.
    fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
        Poll::Ready(None)
    }

    /// Wraps this handler in a [`MapInEvent`] so it accepts `TNewIn` events from the
    /// behaviour.
    ///
    /// `map` extracts a `&Self::FromBehaviour` from each incoming event. A `None` result
    /// presumably means the event is not meant for this handler.
    fn map_in_event<TNewIn, TMap>(self, map: TMap) -> MapInEvent<Self, TNewIn, TMap>
    where
        Self: Sized,
        TMap: Fn(&TNewIn) -> Option<&Self::FromBehaviour>,
    {
        MapInEvent::new(self, map)
    }

    /// Wraps this handler in a [`MapOutEvent`] that converts each reported
    /// `Self::ToBehaviour` with `map`.
    fn map_out_event<TMap, TNewOut>(self, map: TMap) -> MapOutEvent<Self, TMap>
    where
        Self: Sized,
        TMap: FnMut(Self::ToBehaviour) -> TNewOut,
    {
        MapOutEvent::new(self, map)
    }

    /// Combines this handler with `other` into a [`ConnectionHandlerSelect`].
    fn select<TProto2>(self, other: TProto2) -> ConnectionHandlerSelect<Self, TProto2>
    where
        Self: Sized,
    {
        ConnectionHandlerSelect::new(self, other)
    }

    /// Handles an event sent by the behaviour.
    fn on_behaviour_event(&mut self, _event: Self::FromBehaviour);

    /// Handles a connection-level [`ConnectionEvent`]: a negotiated substream, an upgrade
    /// error, an address change, or a change of local/remote supported protocols.
    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<
            Self::InboundProtocol,
            Self::OutboundProtocol,
            Self::InboundOpenInfo,
            Self::OutboundOpenInfo,
        >,
    );
}
226
/// Connection-level event delivered to a handler through
/// [`ConnectionHandler::on_connection_event`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must use a wildcard arm.
#[non_exhaustive]
pub enum ConnectionEvent<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI = (), OOI = ()> {
    /// An inbound substream finished its upgrade; carries the upgrade output and info.
    FullyNegotiatedInbound(FullyNegotiatedInbound<IP, IOI>),
    /// An outbound substream finished its upgrade; carries the upgrade output and info.
    FullyNegotiatedOutbound(FullyNegotiatedOutbound<OP, OOI>),
    /// The connection's address changed; borrows the new [`Multiaddr`].
    AddressChange(AddressChange<'a>),
    /// Upgrading an outbound (dialed) substream failed.
    DialUpgradeError(DialUpgradeError<OOI, OP>),
    /// Upgrading an inbound (listened) substream failed.
    ListenUpgradeError(ListenUpgradeError<IOI, IP>),
    /// The set of locally supported protocols changed.
    LocalProtocolsChange(ProtocolsChange<'a>),
    /// The set of protocols supported by the remote changed.
    RemoteProtocolsChange(ProtocolsChange<'a>),
}
246
247impl<IP, OP, IOI, OOI> fmt::Debug for ConnectionEvent<'_, IP, OP, IOI, OOI>
248where
249 IP: InboundUpgradeSend + fmt::Debug,
250 IP::Output: fmt::Debug,
251 IP::Error: fmt::Debug,
252 OP: OutboundUpgradeSend + fmt::Debug,
253 OP::Output: fmt::Debug,
254 OP::Error: fmt::Debug,
255 IOI: fmt::Debug,
256 OOI: fmt::Debug,
257{
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 match self {
260 ConnectionEvent::FullyNegotiatedInbound(v) => {
261 f.debug_tuple("FullyNegotiatedInbound").field(v).finish()
262 }
263 ConnectionEvent::FullyNegotiatedOutbound(v) => {
264 f.debug_tuple("FullyNegotiatedOutbound").field(v).finish()
265 }
266 ConnectionEvent::AddressChange(v) => f.debug_tuple("AddressChange").field(v).finish(),
267 ConnectionEvent::DialUpgradeError(v) => {
268 f.debug_tuple("DialUpgradeError").field(v).finish()
269 }
270 ConnectionEvent::ListenUpgradeError(v) => {
271 f.debug_tuple("ListenUpgradeError").field(v).finish()
272 }
273 ConnectionEvent::LocalProtocolsChange(v) => {
274 f.debug_tuple("LocalProtocolsChange").field(v).finish()
275 }
276 ConnectionEvent::RemoteProtocolsChange(v) => {
277 f.debug_tuple("RemoteProtocolsChange").field(v).finish()
278 }
279 }
280 }
281}
282
283impl<IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI, OOI>
284 ConnectionEvent<'_, IP, OP, IOI, OOI>
285{
286 pub fn is_outbound(&self) -> bool {
288 match self {
289 ConnectionEvent::DialUpgradeError(_) | ConnectionEvent::FullyNegotiatedOutbound(_) => {
290 true
291 }
292 ConnectionEvent::FullyNegotiatedInbound(_)
293 | ConnectionEvent::AddressChange(_)
294 | ConnectionEvent::LocalProtocolsChange(_)
295 | ConnectionEvent::RemoteProtocolsChange(_)
296 | ConnectionEvent::ListenUpgradeError(_) => false,
297 }
298 }
299
300 pub fn is_inbound(&self) -> bool {
302 match self {
303 ConnectionEvent::FullyNegotiatedInbound(_) | ConnectionEvent::ListenUpgradeError(_) => {
304 true
305 }
306 ConnectionEvent::FullyNegotiatedOutbound(_)
307 | ConnectionEvent::AddressChange(_)
308 | ConnectionEvent::LocalProtocolsChange(_)
309 | ConnectionEvent::RemoteProtocolsChange(_)
310 | ConnectionEvent::DialUpgradeError(_) => false,
311 }
312 }
313}
314
/// Payload of [`ConnectionEvent::FullyNegotiatedInbound`].
#[derive(Debug)]
pub struct FullyNegotiatedInbound<IP: InboundUpgradeSend, IOI = ()> {
    /// Output produced by the inbound upgrade.
    pub protocol: IP::Output,
    /// Info value associated with this inbound substream.
    pub info: IOI,
}

/// Payload of [`ConnectionEvent::FullyNegotiatedOutbound`].
#[derive(Debug)]
pub struct FullyNegotiatedOutbound<OP: OutboundUpgradeSend, OOI = ()> {
    /// Output produced by the outbound upgrade.
    pub protocol: OP::Output,
    /// Info value associated with this outbound substream.
    pub info: OOI,
}

/// Payload of [`ConnectionEvent::AddressChange`].
#[derive(Debug)]
pub struct AddressChange<'a> {
    /// The connection's new address.
    pub new_address: &'a Multiaddr,
}
346
/// A change in a set of supported protocols.
///
/// The affected protocols are borrowed (`'a`), typically from a caller-provided buffer.
#[derive(Debug, Clone)]
pub enum ProtocolsChange<'a> {
    /// Protocols that became supported.
    Added(ProtocolsAdded<'a>),
    /// Protocols that are no longer supported.
    Removed(ProtocolsRemoved<'a>),
}
354
355impl<'a> ProtocolsChange<'a> {
356 pub(crate) fn from_initial_protocols<'b, T: AsRef<str> + 'b>(
358 new_protocols: impl IntoIterator<Item = &'b T>,
359 buffer: &'a mut Vec<StreamProtocol>,
360 ) -> Self {
361 buffer.clear();
362 buffer.extend(
363 new_protocols
364 .into_iter()
365 .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok()),
366 );
367
368 ProtocolsChange::Added(ProtocolsAdded {
369 protocols: buffer.iter(),
370 })
371 }
372
373 pub(crate) fn add(
377 existing_protocols: &HashSet<StreamProtocol>,
378 to_add: HashSet<StreamProtocol>,
379 buffer: &'a mut Vec<StreamProtocol>,
380 ) -> Option<Self> {
381 buffer.clear();
382 buffer.extend(
383 to_add
384 .into_iter()
385 .filter(|i| !existing_protocols.contains(i)),
386 );
387
388 if buffer.is_empty() {
389 return None;
390 }
391
392 Some(Self::Added(ProtocolsAdded {
393 protocols: buffer.iter(),
394 }))
395 }
396
397 pub(crate) fn remove(
403 existing_protocols: &mut HashSet<StreamProtocol>,
404 to_remove: HashSet<StreamProtocol>,
405 buffer: &'a mut Vec<StreamProtocol>,
406 ) -> Option<Self> {
407 buffer.clear();
408 buffer.extend(
409 to_remove
410 .into_iter()
411 .filter_map(|i| existing_protocols.take(&i)),
412 );
413
414 if buffer.is_empty() {
415 return None;
416 }
417
418 Some(Self::Removed(ProtocolsRemoved {
419 protocols: buffer.iter(),
420 }))
421 }
422
423 pub(crate) fn from_full_sets<T: AsRef<str>>(
426 existing_protocols: &mut HashMap<AsStrHashEq<T>, bool>,
427 new_protocols: impl IntoIterator<Item = T>,
428 buffer: &'a mut Vec<StreamProtocol>,
429 ) -> SmallVec<[Self; 2]> {
430 buffer.clear();
431
432 for v in existing_protocols.values_mut() {
434 *v = false;
435 }
436
437 let mut new_protocol_count = 0; for new_protocol in new_protocols {
439 existing_protocols
440 .entry(AsStrHashEq(new_protocol))
441 .and_modify(|v| *v = true) .or_insert_with_key(|k| {
443 buffer.extend(StreamProtocol::try_from_owned(k.0.as_ref().to_owned()).ok());
445 true
446 });
447 new_protocol_count += 1;
448 }
449
450 if new_protocol_count == existing_protocols.len() && buffer.is_empty() {
451 return SmallVec::new();
452 }
453
454 let num_new_protocols = buffer.len();
455 existing_protocols.retain(|p, &mut is_supported| {
459 if !is_supported {
460 buffer.extend(StreamProtocol::try_from_owned(p.0.as_ref().to_owned()).ok());
461 }
462
463 is_supported
464 });
465
466 let (added, removed) = buffer.split_at(num_new_protocols);
467 let mut changes = SmallVec::new();
468 if !added.is_empty() {
469 changes.push(ProtocolsChange::Added(ProtocolsAdded {
470 protocols: added.iter(),
471 }));
472 }
473 if !removed.is_empty() {
474 changes.push(ProtocolsChange::Removed(ProtocolsRemoved {
475 protocols: removed.iter(),
476 }));
477 }
478 changes
479 }
480}
481
482#[derive(Debug, Clone)]
484pub struct ProtocolsAdded<'a> {
485 pub(crate) protocols: slice::Iter<'a, StreamProtocol>,
486}
487
488#[derive(Debug, Clone)]
490pub struct ProtocolsRemoved<'a> {
491 pub(crate) protocols: slice::Iter<'a, StreamProtocol>,
492}
493
494impl<'a> Iterator for ProtocolsAdded<'a> {
495 type Item = &'a StreamProtocol;
496 fn next(&mut self) -> Option<Self::Item> {
497 self.protocols.next()
498 }
499}
500
501impl<'a> Iterator for ProtocolsRemoved<'a> {
502 type Item = &'a StreamProtocol;
503 fn next(&mut self) -> Option<Self::Item> {
504 self.protocols.next()
505 }
506}
507
/// Payload of [`ConnectionEvent::DialUpgradeError`]: upgrading an outbound substream failed.
#[derive(Debug)]
pub struct DialUpgradeError<OOI, OP: OutboundUpgradeSend> {
    /// Info value associated with the failed outbound substream.
    pub info: OOI,
    /// Why the upgrade failed; the upgrade's own error is wrapped in
    /// [`StreamUpgradeError::Apply`].
    pub error: StreamUpgradeError<OP::Error>,
}

/// Payload of [`ConnectionEvent::ListenUpgradeError`]: upgrading an inbound substream failed.
#[derive(Debug)]
pub struct ListenUpgradeError<IOI, IP: InboundUpgradeSend> {
    /// Info value associated with the failed inbound substream.
    pub info: IOI,
    /// The inbound upgrade's own error. Unlike the outbound case, it is not wrapped in
    /// `StreamUpgradeError`.
    pub error: IP::Error,
}
523
/// An upgrade together with its associated info value and a negotiation timeout.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SubstreamProtocol<TUpgrade, TInfo = ()> {
    upgrade: TUpgrade,
    info: TInfo,
    timeout: Duration,
}

impl<TUpgrade, TInfo> SubstreamProtocol<TUpgrade, TInfo> {
    /// Creates a new value with a default timeout of 10 seconds.
    pub fn new(upgrade: TUpgrade, info: TInfo) -> Self {
        Self {
            upgrade,
            info,
            timeout: Duration::from_secs(10),
        }
    }

    /// Replaces the upgrade with `f(upgrade)`, keeping info and timeout.
    pub fn map_upgrade<U, F>(self, f: F) -> SubstreamProtocol<U, TInfo>
    where
        F: FnOnce(TUpgrade) -> U,
    {
        let Self {
            upgrade,
            info,
            timeout,
        } = self;
        SubstreamProtocol {
            upgrade: f(upgrade),
            info,
            timeout,
        }
    }

    /// Replaces the info with `f(info)`, keeping upgrade and timeout.
    pub fn map_info<U, F>(self, f: F) -> SubstreamProtocol<TUpgrade, U>
    where
        F: FnOnce(TInfo) -> U,
    {
        let Self {
            upgrade,
            info,
            timeout,
        } = self;
        SubstreamProtocol {
            upgrade,
            info: f(info),
            timeout,
        }
    }

    /// Returns the same value with its timeout replaced by `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Borrows the upgrade.
    pub fn upgrade(&self) -> &TUpgrade {
        &self.upgrade
    }

    /// Borrows the info value.
    pub fn info(&self) -> &TInfo {
        &self.info
    }

    /// Borrows the timeout.
    pub fn timeout(&self) -> &Duration {
        &self.timeout
    }

    /// Splits into `(upgrade, info)`, discarding the timeout.
    pub fn into_upgrade(self) -> (TUpgrade, TInfo) {
        let Self { upgrade, info, .. } = self;
        (upgrade, info)
    }
}
599
/// Event produced by [`ConnectionHandler::poll`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, TCustom> {
    /// Request to open a new outbound substream.
    OutboundSubstreamRequest {
        /// Upgrade, info and timeout to use for the new substream.
        protocol: SubstreamProtocol<TConnectionUpgrade, TOutboundOpenInfo>,
    },
    /// Report a change in the set of protocols supported by the remote.
    ReportRemoteProtocols(ProtocolSupport),

    /// Deliver a custom event (the handler's `ToBehaviour` type) to the behaviour.
    NotifyBehaviour(TCustom),
}

/// A set of protocols that were added or removed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtocolSupport {
    /// Protocols that are now supported.
    Added(HashSet<StreamProtocol>),
    /// Protocols that are no longer supported.
    Removed(HashSet<StreamProtocol>),
}
623
624impl<TConnectionUpgrade, TOutboundOpenInfo, TCustom>
626 ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, TCustom>
627{
628 pub fn map_outbound_open_info<F, I>(
631 self,
632 map: F,
633 ) -> ConnectionHandlerEvent<TConnectionUpgrade, I, TCustom>
634 where
635 F: FnOnce(TOutboundOpenInfo) -> I,
636 {
637 match self {
638 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
639 ConnectionHandlerEvent::OutboundSubstreamRequest {
640 protocol: protocol.map_info(map),
641 }
642 }
643 ConnectionHandlerEvent::NotifyBehaviour(val) => {
644 ConnectionHandlerEvent::NotifyBehaviour(val)
645 }
646 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
647 ConnectionHandlerEvent::ReportRemoteProtocols(support)
648 }
649 }
650 }
651
652 pub fn map_protocol<F, I>(self, map: F) -> ConnectionHandlerEvent<I, TOutboundOpenInfo, TCustom>
655 where
656 F: FnOnce(TConnectionUpgrade) -> I,
657 {
658 match self {
659 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
660 ConnectionHandlerEvent::OutboundSubstreamRequest {
661 protocol: protocol.map_upgrade(map),
662 }
663 }
664 ConnectionHandlerEvent::NotifyBehaviour(val) => {
665 ConnectionHandlerEvent::NotifyBehaviour(val)
666 }
667 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
668 ConnectionHandlerEvent::ReportRemoteProtocols(support)
669 }
670 }
671 }
672
673 pub fn map_custom<F, I>(
675 self,
676 map: F,
677 ) -> ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, I>
678 where
679 F: FnOnce(TCustom) -> I,
680 {
681 match self {
682 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
683 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }
684 }
685 ConnectionHandlerEvent::NotifyBehaviour(val) => {
686 ConnectionHandlerEvent::NotifyBehaviour(map(val))
687 }
688 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
689 ConnectionHandlerEvent::ReportRemoteProtocols(support)
690 }
691 }
692 }
693}
694
/// Error that can occur while opening and upgrading a substream.
#[derive(Debug)]
pub enum StreamUpgradeError<TUpgrErr> {
    /// Opening the substream timed out.
    Timeout,
    /// The upgrade itself failed with the wrapped error.
    Apply(TUpgrErr),
    /// No protocol could be agreed upon.
    NegotiationFailed,
    /// An I/O error occurred.
    Io(io::Error),
}

impl<TUpgrErr> StreamUpgradeError<TUpgrErr> {
    /// Converts the error carried by `Apply` with `f`.
    ///
    /// All other variants are carried over unchanged and `f` is not called.
    pub fn map_upgrade_err<F, E>(self, f: F) -> StreamUpgradeError<E>
    where
        F: FnOnce(TUpgrErr) -> E,
    {
        match self {
            Self::Apply(inner) => StreamUpgradeError::Apply(f(inner)),
            Self::Io(io_err) => StreamUpgradeError::Io(io_err),
            Self::NegotiationFailed => StreamUpgradeError::NegotiationFailed,
            Self::Timeout => StreamUpgradeError::Timeout,
        }
    }
}
722
723impl<TUpgrErr> fmt::Display for StreamUpgradeError<TUpgrErr>
724where
725 TUpgrErr: error::Error + 'static,
726{
727 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
728 match self {
729 StreamUpgradeError::Timeout => {
730 write!(f, "Timeout error while opening a substream")
731 }
732 StreamUpgradeError::Apply(err) => {
733 write!(f, "Apply: ")?;
734 crate::print_error_chain(f, err)
735 }
736 StreamUpgradeError::NegotiationFailed => {
737 write!(f, "no protocols could be agreed upon")
738 }
739 StreamUpgradeError::Io(e) => {
740 write!(f, "IO error: ")?;
741 crate::print_error_chain(f, e)
742 }
743 }
744 }
745}
746
747impl<TUpgrErr> error::Error for StreamUpgradeError<TUpgrErr>
748where
749 TUpgrErr: error::Error + 'static,
750{
751 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
752 None
753 }
754}
755
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a protocol set from whitespace-separated names, prefixing each with `/`.
    fn protocol_set_of(s: &'static str) -> HashSet<StreamProtocol> {
        s.split_whitespace()
            .map(|p| StreamProtocol::try_from_owned(format!("/{p}")).unwrap())
            .collect()
    }

    /// Runs [`ProtocolsChange::add`] and collects the protocols it reports as added.
    fn test_add(
        existing: &HashSet<StreamProtocol>,
        to_add: HashSet<StreamProtocol>,
    ) -> HashSet<StreamProtocol> {
        ProtocolsChange::add(existing, to_add, &mut Vec::new())
            .into_iter()
            .flat_map(|c| match c {
                ProtocolsChange::Added(a) => a.cloned(),
                ProtocolsChange::Removed(_) => panic!("unexpected removed"),
            })
            .collect::<HashSet<_>>()
    }

    /// Runs [`ProtocolsChange::remove`] and collects the protocols it reports as removed.
    fn test_remove(
        existing: &mut HashSet<StreamProtocol>,
        to_remove: HashSet<StreamProtocol>,
    ) -> HashSet<StreamProtocol> {
        ProtocolsChange::remove(existing, to_remove, &mut Vec::new())
            .into_iter()
            .flat_map(|c| match c {
                ProtocolsChange::Added(_) => panic!("unexpected added"),
                ProtocolsChange::Removed(r) => r.cloned(),
            })
            .collect::<HashSet<_>>()
    }

    #[test]
    fn test_protocol_add_new() {
        let existing = protocol_set_of("a b");
        let to_add = protocol_set_of("b c");

        let change = test_add(&existing, to_add);

        // `add` only reports; it leaves `existing` untouched.
        assert_eq!(existing, protocol_set_of("a b"));
        assert_eq!(change, protocol_set_of("c"));
    }

    #[test]
    fn test_protocol_add_already_present() {
        let existing = protocol_set_of("a b");
        let to_add = protocol_set_of("a");

        let change = test_add(&existing, to_add);

        assert_eq!(change, protocol_set_of(""));
    }

    #[test]
    fn test_protocol_remove_subset() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of("c"));
        assert_eq!(change, protocol_set_of("a b"));
    }

    #[test]
    fn test_protocol_remove_all() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b c");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of("a b c"));
    }

    #[test]
    fn test_protocol_remove_superset() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b c d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of("a b c"));
    }

    #[test]
    fn test_protocol_remove_none() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of("a b c"));
        assert_eq!(change, protocol_set_of(""));
    }

    #[test]
    fn test_protocol_remove_none_from_empty() {
        let mut existing = protocol_set_of("");
        let to_remove = protocol_set_of("d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of(""));
    }

    /// Runs [`ProtocolsChange::from_full_sets`] and returns `[removed, added]`.
    fn test_from_full_sets(
        existing: HashSet<StreamProtocol>,
        new: HashSet<StreamProtocol>,
    ) -> [HashSet<StreamProtocol>; 2] {
        let mut buffer = Vec::new();
        let mut existing = existing
            .iter()
            .map(|p| (AsStrHashEq(p.as_ref()), true))
            .collect::<HashMap<_, _>>();

        let changes = ProtocolsChange::from_full_sets(
            &mut existing,
            new.iter().map(AsRef::as_ref),
            &mut buffer,
        );

        let mut added_changes = HashSet::new();
        let mut removed_changes = HashSet::new();

        for change in changes {
            match change {
                ProtocolsChange::Added(a) => {
                    added_changes.extend(a.cloned());
                }
                ProtocolsChange::Removed(r) => {
                    removed_changes.extend(r.cloned());
                }
            }
        }

        [removed_changes, added_changes]
    }

    #[test]
    fn test_from_full_sets_subset() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("a b");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of(""));
        assert_eq!(removed_changes, protocol_set_of("c"));
    }

    #[test]
    fn test_from_full_sets_superset() {
        let existing = protocol_set_of("a b");
        let new = protocol_set_of("a b c");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("c"));
        assert_eq!(removed_changes, protocol_set_of(""));
    }

    #[test]
    fn test_from_full_sets_intersection() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("b c d");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("d"));
        assert_eq!(removed_changes, protocol_set_of("a"));
    }

    #[test]
    fn test_from_full_sets_disjoint() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("d e f");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("d e f"));
        assert_eq!(removed_changes, protocol_set_of("a b c"));
    }

    #[test]
    fn test_from_full_sets_empty() {
        let existing = protocol_set_of("");
        let new = protocol_set_of("");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of(""));
        assert_eq!(removed_changes, protocol_set_of(""));
    }
}