libp2p_swarm/handler/
multi.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! A [`ConnectionHandler`] implementation that combines multiple other [`ConnectionHandler`]s
22//! indexed by some key.
23
24use std::{
25    cmp,
26    collections::{HashMap, HashSet},
27    error,
28    fmt::{self, Debug},
29    hash::Hash,
30    iter,
31    task::{Context, Poll},
32    time::Duration,
33};
34
35use futures::{future::BoxFuture, prelude::*, ready};
36use rand::Rng;
37
38use crate::{
39    handler::{
40        AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent,
41        DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError,
42        SubstreamProtocol,
43    },
44    upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend},
45    Stream,
46};
47
/// A [`ConnectionHandler`] for multiple [`ConnectionHandler`]s of the same type.
///
/// Every wrapped handler is stored under a key of type `K`; that key is what events,
/// upgrades and open-infos are tagged with when they are routed to or from a handler.
///
/// `Debug` is derived: it prints the same `MultiHandler { handlers: .. }` structure a
/// hand-written `debug_struct` implementation would, while only requiring `K: Debug` and
/// `H: Debug`.
#[derive(Clone, Debug)]
pub struct MultiHandler<K, H> {
    /// The wrapped handlers, indexed by their key.
    handlers: HashMap<K, H>,
}
65
66impl<K, H> MultiHandler<K, H>
67where
68    K: Clone + Debug + Hash + Eq + Send + 'static,
69    H: ConnectionHandler,
70{
71    /// Create and populate a `MultiHandler` from the given handler iterator.
72    ///
73    /// It is an error for any two protocols handlers to share the same protocol name.
74    pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
75    where
76        I: IntoIterator<Item = (K, H)>,
77    {
78        let m = MultiHandler {
79            handlers: HashMap::from_iter(iter),
80        };
81        uniq_proto_names(
82            m.handlers
83                .values()
84                .map(|h| h.listen_protocol().into_upgrade().0),
85        )?;
86        Ok(m)
87    }
88
89    #[expect(deprecated)] // TODO: Remove when {In, Out}boundOpenInfo is fully removed.
90    fn on_listen_upgrade_error(
91        &mut self,
92        ListenUpgradeError {
93            error: (key, error),
94            mut info,
95        }: ListenUpgradeError<
96            <Self as ConnectionHandler>::InboundOpenInfo,
97            <Self as ConnectionHandler>::InboundProtocol,
98        >,
99    ) {
100        if let Some(h) = self.handlers.get_mut(&key) {
101            if let Some(i) = info.take(&key) {
102                h.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
103                    info: i,
104                    error,
105                }));
106            }
107        }
108    }
109}
110
111#[expect(deprecated)] // TODO: Remove when {In, Out}boundOpenInfo is fully removed.
112impl<K, H> ConnectionHandler for MultiHandler<K, H>
113where
114    K: Clone + Debug + Hash + Eq + Send + 'static,
115    H: ConnectionHandler,
116    H::InboundProtocol: InboundUpgradeSend,
117    H::OutboundProtocol: OutboundUpgradeSend,
118{
119    type FromBehaviour = (K, <H as ConnectionHandler>::FromBehaviour);
120    type ToBehaviour = (K, <H as ConnectionHandler>::ToBehaviour);
121    type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
122    type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
123    type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
124    type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
125
126    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
127        let (upgrade, info, timeout) = self
128            .handlers
129            .iter()
130            .map(|(key, handler)| {
131                let proto = handler.listen_protocol();
132                let timeout = *proto.timeout();
133                let (upgrade, info) = proto.into_upgrade();
134                (key.clone(), (upgrade, info, timeout))
135            })
136            .fold(
137                (Upgrade::new(), Info::new(), Duration::from_secs(0)),
138                |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
139                    upg.upgrades.push((k.clone(), u));
140                    inf.infos.push((k, i));
141                    timeout = cmp::max(timeout, t);
142                    (upg, inf, timeout)
143                },
144            );
145        SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
146    }
147
148    fn on_connection_event(
149        &mut self,
150        event: ConnectionEvent<
151            Self::InboundProtocol,
152            Self::OutboundProtocol,
153            Self::InboundOpenInfo,
154            Self::OutboundOpenInfo,
155        >,
156    ) {
157        match event {
158            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
159                protocol,
160                info: (key, arg),
161            }) => {
162                if let Some(h) = self.handlers.get_mut(&key) {
163                    h.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
164                        FullyNegotiatedOutbound {
165                            protocol,
166                            info: arg,
167                        },
168                    ));
169                } else {
170                    tracing::error!("FullyNegotiatedOutbound: no handler for key")
171                }
172            }
173            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
174                protocol: (key, arg),
175                mut info,
176            }) => {
177                if let Some(h) = self.handlers.get_mut(&key) {
178                    if let Some(i) = info.take(&key) {
179                        h.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
180                            FullyNegotiatedInbound {
181                                protocol: arg,
182                                info: i,
183                            },
184                        ));
185                    }
186                } else {
187                    tracing::error!("FullyNegotiatedInbound: no handler for key")
188                }
189            }
190            ConnectionEvent::AddressChange(AddressChange { new_address }) => {
191                for h in self.handlers.values_mut() {
192                    h.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
193                        new_address,
194                    }));
195                }
196            }
197            ConnectionEvent::DialUpgradeError(DialUpgradeError {
198                info: (key, arg),
199                error,
200            }) => {
201                if let Some(h) = self.handlers.get_mut(&key) {
202                    h.on_connection_event(ConnectionEvent::DialUpgradeError(DialUpgradeError {
203                        info: arg,
204                        error,
205                    }));
206                } else {
207                    tracing::error!("DialUpgradeError: no handler for protocol")
208                }
209            }
210            ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
211                self.on_listen_upgrade_error(listen_upgrade_error)
212            }
213            ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
214                for h in self.handlers.values_mut() {
215                    h.on_connection_event(ConnectionEvent::LocalProtocolsChange(
216                        supported_protocols.clone(),
217                    ));
218                }
219            }
220            ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
221                for h in self.handlers.values_mut() {
222                    h.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
223                        supported_protocols.clone(),
224                    ));
225                }
226            }
227        }
228    }
229
230    fn on_behaviour_event(&mut self, (key, event): Self::FromBehaviour) {
231        if let Some(h) = self.handlers.get_mut(&key) {
232            h.on_behaviour_event(event)
233        } else {
234            tracing::error!("on_behaviour_event: no handler for key")
235        }
236    }
237
238    fn connection_keep_alive(&self) -> bool {
239        self.handlers
240            .values()
241            .map(|h| h.connection_keep_alive())
242            .max()
243            .unwrap_or(false)
244    }
245
246    fn poll(
247        &mut self,
248        cx: &mut Context<'_>,
249    ) -> Poll<
250        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
251    > {
252        // Calling `gen_range(0, 0)` (see below) would panic, so we have return early to avoid
253        // that situation.
254        if self.handlers.is_empty() {
255            return Poll::Pending;
256        }
257
258        // Not always polling handlers in the same order
259        // should give anyone the chance to make progress.
260        let pos = rand::thread_rng().gen_range(0..self.handlers.len());
261
262        for (k, h) in self.handlers.iter_mut().skip(pos) {
263            if let Poll::Ready(e) = h.poll(cx) {
264                let e = e
265                    .map_outbound_open_info(|i| (k.clone(), i))
266                    .map_custom(|p| (k.clone(), p));
267                return Poll::Ready(e);
268            }
269        }
270
271        for (k, h) in self.handlers.iter_mut().take(pos) {
272            if let Poll::Ready(e) = h.poll(cx) {
273                let e = e
274                    .map_outbound_open_info(|i| (k.clone(), i))
275                    .map_custom(|p| (k.clone(), p));
276                return Poll::Ready(e);
277            }
278        }
279
280        Poll::Pending
281    }
282
283    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
284        for (k, h) in self.handlers.iter_mut() {
285            let Some(e) = ready!(h.poll_close(cx)) else {
286                continue;
287            };
288            return Poll::Ready(Some((k.clone(), e)));
289        }
290
291        Poll::Ready(None)
292    }
293}
294
295/// Split [`MultiHandler`] into parts.
296impl<K, H> IntoIterator for MultiHandler<K, H> {
297    type Item = <Self::IntoIter as Iterator>::Item;
298    type IntoIter = std::collections::hash_map::IntoIter<K, H>;
299
300    fn into_iter(self) -> Self::IntoIter {
301        self.handlers.into_iter()
302    }
303}
304
/// Index and protocol name pair used as `UpgradeInfo::Info`.
///
/// The first field is an index, the second the wrapped protocol name.
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);

/// Exposes the wrapped protocol name as a string slice; the index plays no part in it.
impl<H> AsRef<str> for IndexedProtoName<H>
where
    H: AsRef<str>,
{
    fn as_ref(&self) -> &str {
        let IndexedProtoName(_, name) = self;
        name.as_ref()
    }
}
314
/// The aggregated `InboundOpenInfo`s of supported inbound substream protocols.
///
/// Stored as a list of `(key, info)` pairs.
#[derive(Clone)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>,
}

impl<K, I> Info<K, I>
where
    K: Eq,
{
    /// Create an `Info` without any entries.
    fn new() -> Self {
        Self { infos: Vec::new() }
    }

    /// Remove and return the info of the first entry whose key equals `k`.
    ///
    /// Returns `None` when no entry has that key. The remaining entries keep their order.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|(key, _)| key == k)?;
        let (_, info) = self.infos.remove(index);
        Some(info)
    }
}
333
/// Inbound and outbound upgrade for all [`ConnectionHandler`]s.
///
/// Holds one `(key, upgrade)` pair per entry; the position of a pair in the list is the index
/// under which its protocol names are advertised.
///
/// `Debug` is derived: it prints the same `Upgrade { upgrades: .. }` structure a hand-written
/// `debug_struct` implementation would, while only requiring `K: Debug` and `H: Debug`.
#[derive(Clone, Debug)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Create an `Upgrade` without any entries.
    fn new() -> Self {
        Upgrade {
            upgrades: Vec::new(),
        }
    }
}
359
360impl<K, H> UpgradeInfoSend for Upgrade<K, H>
361where
362    H: UpgradeInfoSend,
363    K: Send + 'static,
364{
365    type Info = IndexedProtoName<H::Info>;
366    type InfoIter = std::vec::IntoIter<Self::Info>;
367
368    fn protocol_info(&self) -> Self::InfoIter {
369        self.upgrades
370            .iter()
371            .enumerate()
372            .flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
373            .map(|(i, h)| IndexedProtoName(i, h))
374            .collect::<Vec<_>>()
375            .into_iter()
376    }
377}
378
379impl<K, H> InboundUpgradeSend for Upgrade<K, H>
380where
381    H: InboundUpgradeSend,
382    K: Send + 'static,
383{
384    type Output = (K, <H as InboundUpgradeSend>::Output);
385    type Error = (K, <H as InboundUpgradeSend>::Error);
386    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
387
388    fn upgrade_inbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
389        let IndexedProtoName(index, info) = info;
390        let (key, upgrade) = self.upgrades.remove(index);
391        upgrade
392            .upgrade_inbound(resource, info)
393            .map(move |out| match out {
394                Ok(o) => Ok((key, o)),
395                Err(e) => Err((key, e)),
396            })
397            .boxed()
398    }
399}
400
401impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
402where
403    H: OutboundUpgradeSend,
404    K: Send + 'static,
405{
406    type Output = (K, <H as OutboundUpgradeSend>::Output);
407    type Error = (K, <H as OutboundUpgradeSend>::Error);
408    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
409
410    fn upgrade_outbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
411        let IndexedProtoName(index, info) = info;
412        let (key, upgrade) = self.upgrades.remove(index);
413        upgrade
414            .upgrade_outbound(resource, info)
415            .map(move |out| match out {
416                Ok(o) => Ok((key, o)),
417                Err(e) => Err((key, e)),
418            })
419            .boxed()
420    }
421}
422
423/// Check that no two protocol names are equal.
424fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
425where
426    I: Iterator<Item = T>,
427    T: UpgradeInfoSend,
428{
429    let mut set = HashSet::new();
430    for infos in iter {
431        for i in infos.protocol_info() {
432            let v = Vec::from(i.as_ref());
433            if set.contains(&v) {
434                return Err(DuplicateProtonameError(v));
435            } else {
436                set.insert(v);
437            }
438        }
439    }
440    Ok(())
441}
442
/// Error signalling that one protocol name was found in more than one handler.
///
/// Wraps the raw bytes of the offending protocol name.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The bytes of the protocol name that was encountered more than once.
    pub fn protocol_name(&self) -> &[u8] {
        self.0.as_slice()
    }
}

/// Shows the name as text when its bytes are valid UTF-8 and as a byte list otherwise.
impl fmt::Display for DuplicateProtonameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(s) => write!(f, "duplicate protocol name: {s}"),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}