1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
68
69#[cfg(feature = "cbor")]
70pub mod cbor;
71mod codec;
72mod handler;
73#[cfg(feature = "json")]
74pub mod json;
75
76use std::{
77 collections::{HashMap, HashSet, VecDeque},
78 fmt, io,
79 sync::{atomic::AtomicU64, Arc},
80 task::{Context, Poll},
81 time::Duration,
82};
83
84pub use codec::Codec;
85use futures::channel::oneshot;
86use handler::Handler;
87pub use handler::ProtocolSupport;
88use libp2p_core::{transport::PortUse, ConnectedPoint, Endpoint, Multiaddr};
89use libp2p_identity::PeerId;
90use libp2p_swarm::{
91 behaviour::{AddressChange, ConnectionClosed, DialFailure, FromSwarm},
92 dial_opts::DialOpts,
93 ConnectionDenied, ConnectionHandler, ConnectionId, NetworkBehaviour, NotifyHandler,
94 PeerAddresses, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
95};
96use smallvec::SmallVec;
97
98use crate::handler::OutboundMessage;
99
/// A request-response message: either a request received from a remote peer or the
/// response to a request this node sent.
#[derive(Debug)]
pub enum Message<TRequest, TResponse, TChannelResponse = TResponse> {
    /// A request received from a remote peer.
    Request {
        /// The ID of this request. It is allocated from a counter owned by the local
        /// [`Behaviour`], so it is only unique among inbound requests of that behaviour.
        request_id: InboundRequestId,
        /// The request message.
        request: TRequest,
        /// The channel through which the response must be sent.
        ///
        /// Pass it to [`Behaviour::send_response`] together with the response. Dropping it
        /// without responding is reported as [`InboundFailure::ResponseOmission`].
        channel: ResponseChannel<TChannelResponse>,
    },
    /// A response to a request previously sent by this node.
    Response {
        /// The ID of the request that produced this response, as returned by
        /// [`Behaviour::send_request`].
        request_id: OutboundRequestId,
        /// The response message.
        response: TResponse,
    },
}
126
/// The events emitted by a request-response [`Behaviour`].
#[derive(Debug)]
pub enum Event<TRequest, TResponse, TChannelResponse = TResponse> {
    /// An incoming message (request or response).
    Message {
        /// The peer that sent the message.
        peer: PeerId,
        /// The connection the message was received on.
        connection_id: ConnectionId,
        /// The incoming message.
        message: Message<TRequest, TResponse, TChannelResponse>,
    },
    /// An outbound request failed.
    OutboundFailure {
        /// The peer to whom the request was sent.
        peer: PeerId,
        /// The connection (or dial attempt) the failure is associated with.
        connection_id: ConnectionId,
        /// The (local) ID of the failed request.
        request_id: OutboundRequestId,
        /// The error that occurred.
        error: OutboundFailure,
    },
    /// An inbound request failed.
    InboundFailure {
        /// The peer from whom the request was received.
        peer: PeerId,
        /// The connection the request was received on.
        connection_id: ConnectionId,
        /// The ID of the failed inbound request.
        request_id: InboundRequestId,
        /// The error that occurred.
        error: InboundFailure,
    },
    /// The connection handler reported that the response to an inbound request was sent.
    ResponseSent {
        /// The peer to whom the response was sent.
        peer: PeerId,
        /// The connection the response was sent on.
        connection_id: ConnectionId,
        /// The ID of the inbound request whose response was sent.
        request_id: InboundRequestId,
    },
}
174
/// Possible failures occurring in the context of sending an outbound request and receiving
/// the response.
#[derive(Debug)]
pub enum OutboundFailure {
    /// The request could not be sent because dialing the peer failed.
    DialFailure,
    /// The request timed out before a response was received.
    Timeout,
    /// The connection closed before a response was received.
    ConnectionClosed,
    /// The remote supports none of the requested protocols.
    UnsupportedProtocols,
    /// An IO failure happened on the outbound stream.
    Io(io::Error),
}

impl fmt::Display for OutboundFailure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except `Io` renders a fixed message; `Io` embeds the inner error.
        let text = match self {
            Self::DialFailure => "Failed to dial the requested peer",
            Self::Timeout => "Timeout while waiting for a response",
            Self::ConnectionClosed => "Connection was closed before a response was received",
            Self::UnsupportedProtocols => "The remote supports none of the requested protocols",
            Self::Io(err) => return write!(f, "IO error on outbound stream: {err}"),
        };
        f.write_str(text)
    }
}

impl std::error::Error for OutboundFailure {}
214
/// Possible failures occurring in the context of receiving an inbound request and sending
/// a response.
#[derive(Debug)]
pub enum InboundFailure {
    /// The inbound request timed out, either while receiving the request or before the
    /// response was sent.
    Timeout,
    /// The connection closed before a response could be sent.
    ConnectionClosed,
    /// The local peer supports none of the protocols requested by the remote.
    UnsupportedProtocols,
    /// The [`ResponseChannel`] was dropped without a response being sent.
    ResponseOmission,
    /// An IO failure happened on the inbound stream.
    Io(io::Error),
}

impl fmt::Display for InboundFailure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except `Io` renders a fixed message; `Io` embeds the inner error.
        let text = match self {
            Self::Timeout => "Timeout while receiving request or sending response",
            Self::ConnectionClosed => "Connection was closed before a response could be sent",
            Self::UnsupportedProtocols => {
                "The local peer supports none of the protocols requested by the remote"
            }
            Self::ResponseOmission => {
                "The response channel was dropped without sending a response to the remote"
            }
            Self::Io(err) => return write!(f, "IO error on inbound stream: {err}"),
        };
        f.write_str(text)
    }
}

impl std::error::Error for InboundFailure {}
260
261#[derive(Debug)]
265pub struct ResponseChannel<TResponse> {
266 sender: oneshot::Sender<TResponse>,
267}
268
269impl<TResponse> ResponseChannel<TResponse> {
270 pub fn is_open(&self) -> bool {
278 !self.sender.is_canceled()
279 }
280}
281
/// The ID of an inbound request.
///
/// IDs are drawn from a counter shared by the [`Behaviour`] with its connection handlers.
/// They are therefore only unique among the inbound requests of that behaviour.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct InboundRequestId(u64);

impl fmt::Display for InboundRequestId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        write!(f, "{raw}")
    }
}

/// The ID of an outbound request, as returned by [`Behaviour::send_request`].
///
/// IDs are drawn from a counter owned by the [`Behaviour`]. They are therefore only unique
/// among the outbound requests of that behaviour.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct OutboundRequestId(u64);

impl fmt::Display for OutboundRequestId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        write!(f, "{raw}")
    }
}
307
/// The configuration for a request-response [`Behaviour`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Timeout handed to every connection handler as its request timeout.
    request_timeout: Duration,
    /// Concurrent-stream limit handed to every connection handler.
    max_concurrent_streams: usize,
}

impl Default for Config {
    /// A 10 second request timeout and a limit of 100 concurrent streams.
    fn default() -> Self {
        let request_timeout = Duration::from_secs(10);
        let max_concurrent_streams = 100;
        Self {
            request_timeout,
            max_concurrent_streams,
        }
    }
}

impl Config {
    /// Sets the request timeout in place.
    #[deprecated(note = "Use `Config::with_request_timeout` for one-liner constructions.")]
    pub fn set_request_timeout(&mut self, v: Duration) -> &mut Self {
        self.request_timeout = v;
        self
    }

    /// Returns this configuration with the request timeout replaced by `v`.
    pub fn with_request_timeout(self, v: Duration) -> Self {
        Self {
            request_timeout: v,
            ..self
        }
    }

    /// Returns this configuration with the concurrent-stream limit replaced by
    /// `num_streams`.
    pub fn with_max_concurrent_streams(self, num_streams: usize) -> Self {
        Self {
            max_concurrent_streams: num_streams,
            ..self
        }
    }
}
344
/// A request/response protocol behaviour for some message codec.
pub struct Behaviour<TCodec>
where
    TCodec: Codec + Clone + Send + 'static,
{
    /// Protocols accepted for inbound requests (passed to every new handler).
    inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
    /// Protocols offered for outbound requests (copied into every `OutboundMessage`).
    outbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
    /// The ID that the next outbound request will receive.
    next_outbound_request_id: OutboundRequestId,
    /// Counter for inbound request IDs, shared with every connection handler.
    next_inbound_request_id: Arc<AtomicU64>,
    /// The protocol configuration.
    config: Config,
    /// The codec cloned into every connection handler.
    codec: TCodec,
    /// Events waiting to be returned from `poll`.
    pending_events:
        VecDeque<ToSwarm<Event<TCodec::Request, TCodec::Response>, OutboundMessage<TCodec>>>,
    /// The connections per peer, registered when a handler is created, together with the
    /// requests still pending on each of them.
    connected: HashMap<PeerId, SmallVec<[Connection; 2]>>,
    /// Peer addresses, maintained from swarm events and the deprecated `add_address` /
    /// `remove_address`; consulted when dialing a peer.
    addresses: PeerAddresses,
    /// Requests for peers without a connection, waiting for the next connection to be
    /// established with that peer.
    pending_outbound_requests: HashMap<PeerId, SmallVec<[OutboundMessage<TCodec>; 10]>>,
}
374
375impl<TCodec> Behaviour<TCodec>
376where
377 TCodec: Codec + Default + Clone + Send + 'static,
378{
379 pub fn new<I>(protocols: I, cfg: Config) -> Self
382 where
383 I: IntoIterator<Item = (TCodec::Protocol, ProtocolSupport)>,
384 {
385 Self::with_codec(TCodec::default(), protocols, cfg)
386 }
387}
388
389impl<TCodec> Behaviour<TCodec>
390where
391 TCodec: Codec + Clone + Send + 'static,
392{
393 pub fn with_codec<I>(codec: TCodec, protocols: I, cfg: Config) -> Self
396 where
397 I: IntoIterator<Item = (TCodec::Protocol, ProtocolSupport)>,
398 {
399 let mut inbound_protocols = SmallVec::new();
400 let mut outbound_protocols = SmallVec::new();
401 for (p, s) in protocols {
402 if s.inbound() {
403 inbound_protocols.push(p.clone());
404 }
405 if s.outbound() {
406 outbound_protocols.push(p.clone());
407 }
408 }
409 Behaviour {
410 inbound_protocols,
411 outbound_protocols,
412 next_outbound_request_id: OutboundRequestId(1),
413 next_inbound_request_id: Arc::new(AtomicU64::new(1)),
414 config: cfg,
415 codec,
416 pending_events: VecDeque::new(),
417 connected: HashMap::new(),
418 pending_outbound_requests: HashMap::new(),
419 addresses: PeerAddresses::default(),
420 }
421 }
422
423 pub fn send_request(&mut self, peer: &PeerId, request: TCodec::Request) -> OutboundRequestId {
436 let request_id = self.next_outbound_request_id();
437 let request = OutboundMessage {
438 request_id,
439 request,
440 protocols: self.outbound_protocols.clone(),
441 };
442
443 if let Some(request) = self.try_send_request(peer, request) {
444 self.pending_events.push_back(ToSwarm::Dial {
445 opts: DialOpts::peer_id(*peer).build(),
446 });
447 self.pending_outbound_requests
448 .entry(*peer)
449 .or_default()
450 .push(request);
451 }
452
453 request_id
454 }
455
456 pub fn send_response(
468 &mut self,
469 ch: ResponseChannel<TCodec::Response>,
470 rs: TCodec::Response,
471 ) -> Result<(), TCodec::Response> {
472 ch.sender.send(rs)
473 }
474
475 #[deprecated(note = "Use `Swarm::add_peer_address` instead.")]
484 pub fn add_address(&mut self, peer: &PeerId, address: Multiaddr) -> bool {
485 self.addresses.add(*peer, address)
486 }
487
488 #[deprecated(note = "Will be removed with the next breaking release and won't be replaced.")]
490 pub fn remove_address(&mut self, peer: &PeerId, address: &Multiaddr) {
491 self.addresses.remove(peer, address);
492 }
493
494 pub fn is_connected(&self, peer: &PeerId) -> bool {
496 if let Some(connections) = self.connected.get(peer) {
497 !connections.is_empty()
498 } else {
499 false
500 }
501 }
502
503 pub fn is_pending_outbound(&self, peer: &PeerId, request_id: &OutboundRequestId) -> bool {
507 let est_conn = self
509 .connected
510 .get(peer)
511 .map(|cs| {
512 cs.iter()
513 .any(|c| c.pending_outbound_responses.contains(request_id))
514 })
515 .unwrap_or(false);
516 let pen_conn = self
518 .pending_outbound_requests
519 .get(peer)
520 .map(|rps| rps.iter().any(|rp| rp.request_id == *request_id))
521 .unwrap_or(false);
522
523 est_conn || pen_conn
524 }
525
526 pub fn is_pending_inbound(&self, peer: &PeerId, request_id: &InboundRequestId) -> bool {
530 self.connected
531 .get(peer)
532 .map(|cs| {
533 cs.iter()
534 .any(|c| c.pending_inbound_responses.contains(request_id))
535 })
536 .unwrap_or(false)
537 }
538
539 fn next_outbound_request_id(&mut self) -> OutboundRequestId {
541 let request_id = self.next_outbound_request_id;
542 self.next_outbound_request_id.0 += 1;
543 request_id
544 }
545
546 fn try_send_request(
550 &mut self,
551 peer: &PeerId,
552 request: OutboundMessage<TCodec>,
553 ) -> Option<OutboundMessage<TCodec>> {
554 if let Some(connections) = self.connected.get_mut(peer) {
555 if connections.is_empty() {
556 return Some(request);
557 }
558 let ix = (request.request_id.0 as usize) % connections.len();
559 let conn = &mut connections[ix];
560 conn.pending_outbound_responses.insert(request.request_id);
561 self.pending_events.push_back(ToSwarm::NotifyHandler {
562 peer_id: *peer,
563 handler: NotifyHandler::One(conn.id),
564 event: request,
565 });
566 None
567 } else {
568 Some(request)
569 }
570 }
571
572 fn remove_pending_outbound_response(
578 &mut self,
579 peer: &PeerId,
580 connection_id: ConnectionId,
581 request: OutboundRequestId,
582 ) -> bool {
583 self.get_connection_mut(peer, connection_id)
584 .map(|c| c.pending_outbound_responses.remove(&request))
585 .unwrap_or(false)
586 }
587
588 fn remove_pending_inbound_response(
594 &mut self,
595 peer: &PeerId,
596 connection_id: ConnectionId,
597 request: InboundRequestId,
598 ) -> bool {
599 self.get_connection_mut(peer, connection_id)
600 .map(|c| c.pending_inbound_responses.remove(&request))
601 .unwrap_or(false)
602 }
603
604 fn get_connection_mut(
607 &mut self,
608 peer: &PeerId,
609 connection_id: ConnectionId,
610 ) -> Option<&mut Connection> {
611 self.connected
612 .get_mut(peer)
613 .and_then(|connections| connections.iter_mut().find(|c| c.id == connection_id))
614 }
615
616 fn on_address_change(
617 &mut self,
618 AddressChange {
619 peer_id,
620 connection_id,
621 new,
622 ..
623 }: AddressChange,
624 ) {
625 let new_address = match new {
626 ConnectedPoint::Dialer { address, .. } => Some(address.clone()),
627 ConnectedPoint::Listener { .. } => None,
628 };
629 let connections = self
630 .connected
631 .get_mut(&peer_id)
632 .expect("Address change can only happen on an established connection.");
633
634 let connection = connections
635 .iter_mut()
636 .find(|c| c.id == connection_id)
637 .expect("Address change can only happen on an established connection.");
638 connection.remote_address = new_address;
639 }
640
641 fn on_connection_closed(
642 &mut self,
643 ConnectionClosed {
644 peer_id,
645 connection_id,
646 remaining_established,
647 ..
648 }: ConnectionClosed,
649 ) {
650 let connections = self
651 .connected
652 .get_mut(&peer_id)
653 .expect("Expected some established connection to peer before closing.");
654
655 let connection = connections
656 .iter()
657 .position(|c| c.id == connection_id)
658 .map(|p: usize| connections.remove(p))
659 .expect("Expected connection to be established before closing.");
660
661 debug_assert_eq!(connections.is_empty(), remaining_established == 0);
662 if connections.is_empty() {
663 self.connected.remove(&peer_id);
664 }
665
666 for request_id in connection.pending_inbound_responses {
667 self.pending_events
668 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
669 peer: peer_id,
670 connection_id,
671 request_id,
672 error: InboundFailure::ConnectionClosed,
673 }));
674 }
675
676 for request_id in connection.pending_outbound_responses {
677 self.pending_events
678 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
679 peer: peer_id,
680 connection_id,
681 request_id,
682 error: OutboundFailure::ConnectionClosed,
683 }));
684 }
685 }
686
687 fn on_dial_failure(
688 &mut self,
689 DialFailure {
690 peer_id,
691 connection_id,
692 ..
693 }: DialFailure,
694 ) {
695 if let Some(peer) = peer_id {
696 if let Some(pending) = self.pending_outbound_requests.remove(&peer) {
703 for request in pending {
704 self.pending_events
705 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
706 peer,
707 connection_id,
708 request_id: request.request_id,
709 error: OutboundFailure::DialFailure,
710 }));
711 }
712 }
713 }
714 }
715
716 fn preload_new_handler(
719 &mut self,
720 handler: &mut Handler<TCodec>,
721 peer: PeerId,
722 connection_id: ConnectionId,
723 remote_address: Option<Multiaddr>,
724 ) {
725 let mut connection = Connection::new(connection_id, remote_address);
726
727 if let Some(pending_requests) = self.pending_outbound_requests.remove(&peer) {
728 for request in pending_requests {
729 connection
730 .pending_outbound_responses
731 .insert(request.request_id);
732 handler.on_behaviour_event(request);
733 }
734 }
735
736 self.connected.entry(peer).or_default().push(connection);
737 }
738}
739
740impl<TCodec> NetworkBehaviour for Behaviour<TCodec>
741where
742 TCodec: Codec + Send + Clone + 'static,
743{
744 type ConnectionHandler = Handler<TCodec>;
745 type ToSwarm = Event<TCodec::Request, TCodec::Response>;
746
747 fn handle_established_inbound_connection(
748 &mut self,
749 connection_id: ConnectionId,
750 peer: PeerId,
751 _: &Multiaddr,
752 _: &Multiaddr,
753 ) -> Result<THandler<Self>, ConnectionDenied> {
754 let mut handler = Handler::new(
755 self.inbound_protocols.clone(),
756 self.codec.clone(),
757 self.config.request_timeout,
758 self.next_inbound_request_id.clone(),
759 self.config.max_concurrent_streams,
760 );
761
762 self.preload_new_handler(&mut handler, peer, connection_id, None);
763
764 Ok(handler)
765 }
766
767 fn handle_pending_outbound_connection(
768 &mut self,
769 _connection_id: ConnectionId,
770 maybe_peer: Option<PeerId>,
771 _addresses: &[Multiaddr],
772 _effective_role: Endpoint,
773 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
774 let peer = match maybe_peer {
775 None => return Ok(vec![]),
776 Some(peer) => peer,
777 };
778
779 let mut addresses = Vec::new();
780 if let Some(connections) = self.connected.get(&peer) {
781 addresses.extend(connections.iter().filter_map(|c| c.remote_address.clone()))
782 }
783
784 let cached_addrs = self.addresses.get(&peer);
785 addresses.extend(cached_addrs);
786
787 Ok(addresses)
788 }
789
790 fn handle_established_outbound_connection(
791 &mut self,
792 connection_id: ConnectionId,
793 peer: PeerId,
794 remote_address: &Multiaddr,
795 _: Endpoint,
796 _: PortUse,
797 ) -> Result<THandler<Self>, ConnectionDenied> {
798 let mut handler = Handler::new(
799 self.inbound_protocols.clone(),
800 self.codec.clone(),
801 self.config.request_timeout,
802 self.next_inbound_request_id.clone(),
803 self.config.max_concurrent_streams,
804 );
805
806 self.preload_new_handler(
807 &mut handler,
808 peer,
809 connection_id,
810 Some(remote_address.clone()),
811 );
812
813 Ok(handler)
814 }
815
816 fn on_swarm_event(&mut self, event: FromSwarm) {
817 self.addresses.on_swarm_event(&event);
818 match event {
819 FromSwarm::ConnectionEstablished(_) => {}
820 FromSwarm::ConnectionClosed(connection_closed) => {
821 self.on_connection_closed(connection_closed)
822 }
823 FromSwarm::AddressChange(address_change) => self.on_address_change(address_change),
824 FromSwarm::DialFailure(dial_failure) => self.on_dial_failure(dial_failure),
825 _ => {}
826 }
827 }
828
829 fn on_connection_handler_event(
830 &mut self,
831 peer: PeerId,
832 connection_id: ConnectionId,
833 event: THandlerOutEvent<Self>,
834 ) {
835 match event {
836 handler::Event::Response {
837 request_id,
838 response,
839 } => {
840 let removed =
841 self.remove_pending_outbound_response(&peer, connection_id, request_id);
842 debug_assert!(
843 removed,
844 "Expect request_id to be pending before receiving response.",
845 );
846
847 let message = Message::Response {
848 request_id,
849 response,
850 };
851 self.pending_events
852 .push_back(ToSwarm::GenerateEvent(Event::Message {
853 peer,
854 connection_id,
855 message,
856 }));
857 }
858 handler::Event::Request {
859 request_id,
860 request,
861 sender,
862 } => match self.get_connection_mut(&peer, connection_id) {
863 Some(connection) => {
864 let inserted = connection.pending_inbound_responses.insert(request_id);
865 debug_assert!(inserted, "Expect id of new request to be unknown.");
866
867 let channel = ResponseChannel { sender };
868 let message = Message::Request {
869 request_id,
870 request,
871 channel,
872 };
873 self.pending_events
874 .push_back(ToSwarm::GenerateEvent(Event::Message {
875 peer,
876 connection_id,
877 message,
878 }));
879 }
880 None => {
881 tracing::debug!("Connection ({connection_id}) closed after `Event::Request` ({request_id}) has been emitted.");
882 }
883 },
884 handler::Event::ResponseSent(request_id) => {
885 let removed =
886 self.remove_pending_inbound_response(&peer, connection_id, request_id);
887 debug_assert!(
888 removed,
889 "Expect request_id to be pending before response is sent."
890 );
891
892 self.pending_events
893 .push_back(ToSwarm::GenerateEvent(Event::ResponseSent {
894 peer,
895 connection_id,
896 request_id,
897 }));
898 }
899 handler::Event::ResponseOmission(request_id) => {
900 let removed =
901 self.remove_pending_inbound_response(&peer, connection_id, request_id);
902 debug_assert!(
903 removed,
904 "Expect request_id to be pending before response is omitted.",
905 );
906
907 self.pending_events
908 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
909 peer,
910 connection_id,
911 request_id,
912 error: InboundFailure::ResponseOmission,
913 }));
914 }
915 handler::Event::OutboundTimeout(request_id) => {
916 let removed =
917 self.remove_pending_outbound_response(&peer, connection_id, request_id);
918 debug_assert!(
919 removed,
920 "Expect request_id to be pending before request times out."
921 );
922
923 self.pending_events
924 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
925 peer,
926 connection_id,
927 request_id,
928 error: OutboundFailure::Timeout,
929 }));
930 }
931 handler::Event::OutboundUnsupportedProtocols(request_id) => {
932 let removed =
933 self.remove_pending_outbound_response(&peer, connection_id, request_id);
934 debug_assert!(
935 removed,
936 "Expect request_id to be pending before failing to connect.",
937 );
938
939 self.pending_events
940 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
941 peer,
942 connection_id,
943 request_id,
944 error: OutboundFailure::UnsupportedProtocols,
945 }));
946 }
947 handler::Event::OutboundStreamFailed { request_id, error } => {
948 let removed =
949 self.remove_pending_outbound_response(&peer, connection_id, request_id);
950 debug_assert!(removed, "Expect request_id to be pending upon failure");
951
952 self.pending_events
953 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
954 peer,
955 connection_id,
956 request_id,
957 error: OutboundFailure::Io(error),
958 }))
959 }
960 handler::Event::InboundTimeout(request_id) => {
961 let removed =
962 self.remove_pending_inbound_response(&peer, connection_id, request_id);
963
964 if removed {
965 self.pending_events
966 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
967 peer,
968 connection_id,
969 request_id,
970 error: InboundFailure::Timeout,
971 }));
972 } else {
973 tracing::debug!(
975 "Inbound request timeout for an unknown request_id ({request_id})"
976 );
977 }
978 }
979 handler::Event::InboundStreamFailed { request_id, error } => {
980 let removed =
981 self.remove_pending_inbound_response(&peer, connection_id, request_id);
982
983 if removed {
984 self.pending_events
985 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
986 peer,
987 connection_id,
988 request_id,
989 error: InboundFailure::Io(error),
990 }));
991 } else {
992 tracing::debug!("Inbound failure is reported for an unknown request_id ({request_id}): {error}");
994 }
995 }
996 }
997 }
998
999 #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self))]
1000 fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
1001 if let Some(ev) = self.pending_events.pop_front() {
1002 return Poll::Ready(ev);
1003 } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
1004 self.pending_events.shrink_to_fit();
1005 }
1006
1007 Poll::Pending
1008 }
1009}
1010
1011const EMPTY_QUEUE_SHRINK_THRESHOLD: usize = 100;
1016
1017struct Connection {
1019 id: ConnectionId,
1020 remote_address: Option<Multiaddr>,
1021 pending_outbound_responses: HashSet<OutboundRequestId>,
1025 pending_inbound_responses: HashSet<InboundRequestId>,
1028}
1029
1030impl Connection {
1031 fn new(id: ConnectionId, remote_address: Option<Multiaddr>) -> Self {
1032 Self {
1033 id,
1034 remote_address,
1035 pending_outbound_responses: Default::default(),
1036 pending_inbound_responses: Default::default(),
1037 }
1038 }
1039}