1use std::{
25 cmp,
26 collections::{HashMap, HashSet},
27 error,
28 fmt::{self, Debug},
29 hash::Hash,
30 iter,
31 task::{Context, Poll},
32 time::Duration,
33};
34
35use futures::{future::BoxFuture, prelude::*, ready};
36use rand::Rng;
37
38use crate::{
39 handler::{
40 AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent,
41 DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError,
42 SubstreamProtocol,
43 },
44 upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend},
45 Stream,
46};
47
/// A connection handler that multiplexes several sub-handlers of the same type `H`,
/// each stored under its own key `K`.
///
/// `Debug` is derived so that it only requires `K: Debug` and `H: Debug`. Those are
/// exactly the bounds `HashMap<K, H>: Debug` needs. The printed form is
/// `MultiHandler { handlers: {..} }`.
#[derive(Clone, Debug)]
pub struct MultiHandler<K, H> {
    handlers: HashMap<K, H>,
}
65
66impl<K, H> MultiHandler<K, H>
67where
68 K: Clone + Debug + Hash + Eq + Send + 'static,
69 H: ConnectionHandler,
70{
71 pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
75 where
76 I: IntoIterator<Item = (K, H)>,
77 {
78 let m = MultiHandler {
79 handlers: HashMap::from_iter(iter),
80 };
81 uniq_proto_names(
82 m.handlers
83 .values()
84 .map(|h| h.listen_protocol().into_upgrade().0),
85 )?;
86 Ok(m)
87 }
88
89 #[expect(deprecated)] fn on_listen_upgrade_error(
91 &mut self,
92 ListenUpgradeError {
93 error: (key, error),
94 mut info,
95 }: ListenUpgradeError<
96 <Self as ConnectionHandler>::InboundOpenInfo,
97 <Self as ConnectionHandler>::InboundProtocol,
98 >,
99 ) {
100 if let Some(h) = self.handlers.get_mut(&key) {
101 if let Some(i) = info.take(&key) {
102 h.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
103 info: i,
104 error,
105 }));
106 }
107 }
108 }
109}
110
111#[expect(deprecated)] impl<K, H> ConnectionHandler for MultiHandler<K, H>
113where
114 K: Clone + Debug + Hash + Eq + Send + 'static,
115 H: ConnectionHandler,
116 H::InboundProtocol: InboundUpgradeSend,
117 H::OutboundProtocol: OutboundUpgradeSend,
118{
119 type FromBehaviour = (K, <H as ConnectionHandler>::FromBehaviour);
120 type ToBehaviour = (K, <H as ConnectionHandler>::ToBehaviour);
121 type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
122 type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
123 type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
124 type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
125
126 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
127 let (upgrade, info, timeout) = self
128 .handlers
129 .iter()
130 .map(|(key, handler)| {
131 let proto = handler.listen_protocol();
132 let timeout = *proto.timeout();
133 let (upgrade, info) = proto.into_upgrade();
134 (key.clone(), (upgrade, info, timeout))
135 })
136 .fold(
137 (Upgrade::new(), Info::new(), Duration::from_secs(0)),
138 |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
139 upg.upgrades.push((k.clone(), u));
140 inf.infos.push((k, i));
141 timeout = cmp::max(timeout, t);
142 (upg, inf, timeout)
143 },
144 );
145 SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
146 }
147
148 fn on_connection_event(
149 &mut self,
150 event: ConnectionEvent<
151 Self::InboundProtocol,
152 Self::OutboundProtocol,
153 Self::InboundOpenInfo,
154 Self::OutboundOpenInfo,
155 >,
156 ) {
157 match event {
158 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
159 protocol,
160 info: (key, arg),
161 }) => {
162 if let Some(h) = self.handlers.get_mut(&key) {
163 h.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
164 FullyNegotiatedOutbound {
165 protocol,
166 info: arg,
167 },
168 ));
169 } else {
170 tracing::error!("FullyNegotiatedOutbound: no handler for key")
171 }
172 }
173 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
174 protocol: (key, arg),
175 mut info,
176 }) => {
177 if let Some(h) = self.handlers.get_mut(&key) {
178 if let Some(i) = info.take(&key) {
179 h.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
180 FullyNegotiatedInbound {
181 protocol: arg,
182 info: i,
183 },
184 ));
185 }
186 } else {
187 tracing::error!("FullyNegotiatedInbound: no handler for key")
188 }
189 }
190 ConnectionEvent::AddressChange(AddressChange { new_address }) => {
191 for h in self.handlers.values_mut() {
192 h.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
193 new_address,
194 }));
195 }
196 }
197 ConnectionEvent::DialUpgradeError(DialUpgradeError {
198 info: (key, arg),
199 error,
200 }) => {
201 if let Some(h) = self.handlers.get_mut(&key) {
202 h.on_connection_event(ConnectionEvent::DialUpgradeError(DialUpgradeError {
203 info: arg,
204 error,
205 }));
206 } else {
207 tracing::error!("DialUpgradeError: no handler for protocol")
208 }
209 }
210 ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
211 self.on_listen_upgrade_error(listen_upgrade_error)
212 }
213 ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
214 for h in self.handlers.values_mut() {
215 h.on_connection_event(ConnectionEvent::LocalProtocolsChange(
216 supported_protocols.clone(),
217 ));
218 }
219 }
220 ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
221 for h in self.handlers.values_mut() {
222 h.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
223 supported_protocols.clone(),
224 ));
225 }
226 }
227 }
228 }
229
230 fn on_behaviour_event(&mut self, (key, event): Self::FromBehaviour) {
231 if let Some(h) = self.handlers.get_mut(&key) {
232 h.on_behaviour_event(event)
233 } else {
234 tracing::error!("on_behaviour_event: no handler for key")
235 }
236 }
237
238 fn connection_keep_alive(&self) -> bool {
239 self.handlers
240 .values()
241 .map(|h| h.connection_keep_alive())
242 .max()
243 .unwrap_or(false)
244 }
245
246 fn poll(
247 &mut self,
248 cx: &mut Context<'_>,
249 ) -> Poll<
250 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
251 > {
252 if self.handlers.is_empty() {
255 return Poll::Pending;
256 }
257
258 let pos = rand::thread_rng().gen_range(0..self.handlers.len());
261
262 for (k, h) in self.handlers.iter_mut().skip(pos) {
263 if let Poll::Ready(e) = h.poll(cx) {
264 let e = e
265 .map_outbound_open_info(|i| (k.clone(), i))
266 .map_custom(|p| (k.clone(), p));
267 return Poll::Ready(e);
268 }
269 }
270
271 for (k, h) in self.handlers.iter_mut().take(pos) {
272 if let Poll::Ready(e) = h.poll(cx) {
273 let e = e
274 .map_outbound_open_info(|i| (k.clone(), i))
275 .map_custom(|p| (k.clone(), p));
276 return Poll::Ready(e);
277 }
278 }
279
280 Poll::Pending
281 }
282
283 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
284 for (k, h) in self.handlers.iter_mut() {
285 let Some(e) = ready!(h.poll_close(cx)) else {
286 continue;
287 };
288 return Poll::Ready(Some((k.clone(), e)));
289 }
290
291 Poll::Ready(None)
292 }
293}
294
295impl<K, H> IntoIterator for MultiHandler<K, H> {
297 type Item = <Self::IntoIter as Iterator>::Item;
298 type IntoIter = std::collections::hash_map::IntoIter<K, H>;
299
300 fn into_iter(self) -> Self::IntoIter {
301 self.handlers.into_iter()
302 }
303}
304
/// A protocol name paired with the position (in the combined upgrade) of the
/// sub-upgrade that advertised it.
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);

impl<H: AsRef<str>> AsRef<str> for IndexedProtoName<H> {
    /// Returns the wrapped protocol name; the index is not part of it.
    fn as_ref(&self) -> &str {
        let Self(_, name) = self;
        name.as_ref()
    }
}
314
/// Inbound open info for each sub-handler, stored as `(key, info)` pairs.
#[derive(Clone)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>,
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an empty collection.
    fn new() -> Self {
        Self { infos: Vec::new() }
    }

    /// Removes and returns the info stored under `k`, or `None` if there is none.
    ///
    /// Only the first matching entry is removed. The remaining entries keep their
    /// relative order.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|(key, _)| key == k)?;
        let (_, info) = self.infos.remove(index);
        Some(info)
    }
}
333
/// An upgrade that combines the upgrades of several sub-handlers, each paired with the
/// key of the sub-handler it belongs to.
///
/// `Debug` is derived so that it only requires `K: Debug` and `H: Debug`. The previous
/// hand-written impl also demanded `K: Eq + Hash`, which printing a `Vec` never needs.
/// The output format is unchanged: `Upgrade { upgrades: [..] }`.
#[derive(Clone, Debug)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Creates a combined upgrade with no sub-upgrades.
    fn new() -> Self {
        Upgrade {
            upgrades: Vec::new(),
        }
    }
}
359
360impl<K, H> UpgradeInfoSend for Upgrade<K, H>
361where
362 H: UpgradeInfoSend,
363 K: Send + 'static,
364{
365 type Info = IndexedProtoName<H::Info>;
366 type InfoIter = std::vec::IntoIter<Self::Info>;
367
368 fn protocol_info(&self) -> Self::InfoIter {
369 self.upgrades
370 .iter()
371 .enumerate()
372 .flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
373 .map(|(i, h)| IndexedProtoName(i, h))
374 .collect::<Vec<_>>()
375 .into_iter()
376 }
377}
378
379impl<K, H> InboundUpgradeSend for Upgrade<K, H>
380where
381 H: InboundUpgradeSend,
382 K: Send + 'static,
383{
384 type Output = (K, <H as InboundUpgradeSend>::Output);
385 type Error = (K, <H as InboundUpgradeSend>::Error);
386 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
387
388 fn upgrade_inbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
389 let IndexedProtoName(index, info) = info;
390 let (key, upgrade) = self.upgrades.remove(index);
391 upgrade
392 .upgrade_inbound(resource, info)
393 .map(move |out| match out {
394 Ok(o) => Ok((key, o)),
395 Err(e) => Err((key, e)),
396 })
397 .boxed()
398 }
399}
400
401impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
402where
403 H: OutboundUpgradeSend,
404 K: Send + 'static,
405{
406 type Output = (K, <H as OutboundUpgradeSend>::Output);
407 type Error = (K, <H as OutboundUpgradeSend>::Error);
408 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
409
410 fn upgrade_outbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
411 let IndexedProtoName(index, info) = info;
412 let (key, upgrade) = self.upgrades.remove(index);
413 upgrade
414 .upgrade_outbound(resource, info)
415 .map(move |out| match out {
416 Ok(o) => Ok((key, o)),
417 Err(e) => Err((key, e)),
418 })
419 .boxed()
420 }
421}
422
423fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
425where
426 I: Iterator<Item = T>,
427 T: UpgradeInfoSend,
428{
429 let mut set = HashSet::new();
430 for infos in iter {
431 for i in infos.protocol_info() {
432 let v = Vec::from(i.as_ref());
433 if set.contains(&v) {
434 return Err(DuplicateProtonameError(v));
435 } else {
436 set.insert(v);
437 }
438 }
439 }
440 Ok(())
441}
442
/// Error reporting that a protocol name was advertised more than once among the
/// handlers passed to `MultiHandler::try_from_iter`.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The repeated protocol name, as raw bytes.
    pub fn protocol_name(&self) -> &[u8] {
        let Self(name) = self;
        name
    }
}

impl fmt::Display for DuplicateProtonameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Print the name as text when it is valid UTF-8; otherwise print the raw bytes.
        match std::str::from_utf8(&self.0) {
            Ok(s) => write!(f, "duplicate protocol name: {s}"),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}