1use std::{
22 cmp,
23 task::{Context, Poll},
24};
25
26use either::Either;
27use futures::{future, ready};
28use libp2p_core::upgrade::SelectUpgrade;
29
30use crate::{
31 handler::{
32 AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent,
33 DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, InboundUpgradeSend,
34 ListenUpgradeError, OutboundUpgradeSend, StreamUpgradeError, SubstreamProtocol,
35 },
36 upgrade::SendWrapper,
37};
38
/// A connection handler that drives two sub-handlers side by side on one connection.
///
/// Items that belong to the first sub-handler are tagged `Either::Left`, and items that
/// belong to the second are tagged `Either::Right`.
#[derive(Debug, Clone)]
pub struct ConnectionHandlerSelect<TProto1, TProto2> {
    /// The first sub-handler (the `Left` side).
    proto1: TProto1,
    /// The second sub-handler (the `Right` side).
    proto2: TProto2,
}

impl<TProto1, TProto2> ConnectionHandlerSelect<TProto1, TProto2> {
    /// Bundles `proto1` and `proto2` into one combined handler.
    pub(crate) fn new(proto1: TProto1, proto2: TProto2) -> Self {
        Self { proto1, proto2 }
    }

    /// Consumes the combined handler and returns both sub-handlers in their original order.
    pub fn into_inner(self) -> (TProto1, TProto2) {
        let Self { proto1, proto2 } = self;
        (proto1, proto2)
    }
}
58
59impl<S1OOI, S2OOI, S1OP, S2OP>
60 FullyNegotiatedOutbound<Either<SendWrapper<S1OP>, SendWrapper<S2OP>>, Either<S1OOI, S2OOI>>
61where
62 S1OP: OutboundUpgradeSend,
63 S2OP: OutboundUpgradeSend,
64 S1OOI: Send + 'static,
65 S2OOI: Send + 'static,
66{
67 pub(crate) fn transpose(
68 self,
69 ) -> Either<FullyNegotiatedOutbound<S1OP, S1OOI>, FullyNegotiatedOutbound<S2OP, S2OOI>> {
70 match self {
71 FullyNegotiatedOutbound {
72 protocol: future::Either::Left(protocol),
73 info: Either::Left(info),
74 } => Either::Left(FullyNegotiatedOutbound { protocol, info }),
75 FullyNegotiatedOutbound {
76 protocol: future::Either::Right(protocol),
77 info: Either::Right(info),
78 } => Either::Right(FullyNegotiatedOutbound { protocol, info }),
79 _ => panic!("wrong API usage: the protocol doesn't match the upgrade info"),
80 }
81 }
82}
83
84impl<S1IP, S1IOI, S2IP, S2IOI>
85 FullyNegotiatedInbound<SelectUpgrade<SendWrapper<S1IP>, SendWrapper<S2IP>>, (S1IOI, S2IOI)>
86where
87 S1IP: InboundUpgradeSend,
88 S2IP: InboundUpgradeSend,
89{
90 pub(crate) fn transpose(
91 self,
92 ) -> Either<FullyNegotiatedInbound<S1IP, S1IOI>, FullyNegotiatedInbound<S2IP, S2IOI>> {
93 match self {
94 FullyNegotiatedInbound {
95 protocol: future::Either::Left(protocol),
96 info: (i1, _i2),
97 } => Either::Left(FullyNegotiatedInbound { protocol, info: i1 }),
98 FullyNegotiatedInbound {
99 protocol: future::Either::Right(protocol),
100 info: (_i1, i2),
101 } => Either::Right(FullyNegotiatedInbound { protocol, info: i2 }),
102 }
103 }
104}
105
106impl<S1OOI, S2OOI, S1OP, S2OP>
107 DialUpgradeError<Either<S1OOI, S2OOI>, Either<SendWrapper<S1OP>, SendWrapper<S2OP>>>
108where
109 S1OP: OutboundUpgradeSend,
110 S2OP: OutboundUpgradeSend,
111 S1OOI: Send + 'static,
112 S2OOI: Send + 'static,
113{
114 pub(crate) fn transpose(
115 self,
116 ) -> Either<DialUpgradeError<S1OOI, S1OP>, DialUpgradeError<S2OOI, S2OP>> {
117 match self {
118 DialUpgradeError {
119 info: Either::Left(info),
120 error: StreamUpgradeError::Apply(Either::Left(err)),
121 } => Either::Left(DialUpgradeError {
122 info,
123 error: StreamUpgradeError::Apply(err),
124 }),
125 DialUpgradeError {
126 info: Either::Right(info),
127 error: StreamUpgradeError::Apply(Either::Right(err)),
128 } => Either::Right(DialUpgradeError {
129 info,
130 error: StreamUpgradeError::Apply(err),
131 }),
132 DialUpgradeError {
133 info: Either::Left(info),
134 error: e,
135 } => Either::Left(DialUpgradeError {
136 info,
137 error: e.map_upgrade_err(|_| panic!("already handled above")),
138 }),
139 DialUpgradeError {
140 info: Either::Right(info),
141 error: e,
142 } => Either::Right(DialUpgradeError {
143 info,
144 error: e.map_upgrade_err(|_| panic!("already handled above")),
145 }),
146 }
147 }
148}
149
150impl<TProto1, TProto2> ConnectionHandlerSelect<TProto1, TProto2>
151where
152 TProto1: ConnectionHandler,
153 TProto2: ConnectionHandler,
154{
155 #[expect(deprecated)] fn on_listen_upgrade_error(
157 &mut self,
158 ListenUpgradeError {
159 info: (i1, i2),
160 error,
161 }: ListenUpgradeError<
162 <Self as ConnectionHandler>::InboundOpenInfo,
163 <Self as ConnectionHandler>::InboundProtocol,
164 >,
165 ) {
166 match error {
167 Either::Left(error) => {
168 self.proto1
169 .on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
170 info: i1,
171 error,
172 }));
173 }
174 Either::Right(error) => {
175 self.proto2
176 .on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
177 info: i2,
178 error,
179 }));
180 }
181 }
182 }
183}
184
185#[expect(deprecated)] impl<TProto1, TProto2> ConnectionHandler for ConnectionHandlerSelect<TProto1, TProto2>
187where
188 TProto1: ConnectionHandler,
189 TProto2: ConnectionHandler,
190{
191 type FromBehaviour = Either<TProto1::FromBehaviour, TProto2::FromBehaviour>;
192 type ToBehaviour = Either<TProto1::ToBehaviour, TProto2::ToBehaviour>;
193 type InboundProtocol = SelectUpgrade<
194 SendWrapper<<TProto1 as ConnectionHandler>::InboundProtocol>,
195 SendWrapper<<TProto2 as ConnectionHandler>::InboundProtocol>,
196 >;
197 type OutboundProtocol =
198 Either<SendWrapper<TProto1::OutboundProtocol>, SendWrapper<TProto2::OutboundProtocol>>;
199 type OutboundOpenInfo = Either<TProto1::OutboundOpenInfo, TProto2::OutboundOpenInfo>;
200 type InboundOpenInfo = (TProto1::InboundOpenInfo, TProto2::InboundOpenInfo);
201
202 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
203 let proto1 = self.proto1.listen_protocol();
204 let proto2 = self.proto2.listen_protocol();
205 let timeout = *std::cmp::max(proto1.timeout(), proto2.timeout());
206 let (u1, i1) = proto1.into_upgrade();
207 let (u2, i2) = proto2.into_upgrade();
208 let choice = SelectUpgrade::new(SendWrapper(u1), SendWrapper(u2));
209 SubstreamProtocol::new(choice, (i1, i2)).with_timeout(timeout)
210 }
211
212 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
213 match event {
214 Either::Left(event) => self.proto1.on_behaviour_event(event),
215 Either::Right(event) => self.proto2.on_behaviour_event(event),
216 }
217 }
218
219 fn connection_keep_alive(&self) -> bool {
220 cmp::max(
221 self.proto1.connection_keep_alive(),
222 self.proto2.connection_keep_alive(),
223 )
224 }
225
226 fn poll(
227 &mut self,
228 cx: &mut Context<'_>,
229 ) -> Poll<
230 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
231 > {
232 match self.proto1.poll(cx) {
233 Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
234 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Either::Left(event)));
235 }
236 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
237 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
238 protocol: protocol
239 .map_upgrade(|u| Either::Left(SendWrapper(u)))
240 .map_info(Either::Left),
241 });
242 }
243 Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support)) => {
244 return Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support));
245 }
246 Poll::Pending => (),
247 };
248
249 match self.proto2.poll(cx) {
250 Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
251 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Either::Right(
252 event,
253 )));
254 }
255 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
256 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
257 protocol: protocol
258 .map_upgrade(|u| Either::Right(SendWrapper(u)))
259 .map_info(Either::Right),
260 });
261 }
262 Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support)) => {
263 return Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(support));
264 }
265 Poll::Pending => (),
266 };
267
268 Poll::Pending
269 }
270
271 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
272 if let Some(e) = ready!(self.proto1.poll_close(cx)) {
273 return Poll::Ready(Some(Either::Left(e)));
274 }
275
276 if let Some(e) = ready!(self.proto2.poll_close(cx)) {
277 return Poll::Ready(Some(Either::Right(e)));
278 }
279
280 Poll::Ready(None)
281 }
282
283 fn on_connection_event(
284 &mut self,
285 event: ConnectionEvent<
286 Self::InboundProtocol,
287 Self::OutboundProtocol,
288 Self::InboundOpenInfo,
289 Self::OutboundOpenInfo,
290 >,
291 ) {
292 match event {
293 ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => {
294 match fully_negotiated_outbound.transpose() {
295 Either::Left(f) => self
296 .proto1
297 .on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(f)),
298 Either::Right(f) => self
299 .proto2
300 .on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(f)),
301 }
302 }
303 ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => {
304 match fully_negotiated_inbound.transpose() {
305 Either::Left(f) => self
306 .proto1
307 .on_connection_event(ConnectionEvent::FullyNegotiatedInbound(f)),
308 Either::Right(f) => self
309 .proto2
310 .on_connection_event(ConnectionEvent::FullyNegotiatedInbound(f)),
311 }
312 }
313 ConnectionEvent::AddressChange(address) => {
314 self.proto1
315 .on_connection_event(ConnectionEvent::AddressChange(AddressChange {
316 new_address: address.new_address,
317 }));
318
319 self.proto2
320 .on_connection_event(ConnectionEvent::AddressChange(AddressChange {
321 new_address: address.new_address,
322 }));
323 }
324 ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
325 match dial_upgrade_error.transpose() {
326 Either::Left(err) => self
327 .proto1
328 .on_connection_event(ConnectionEvent::DialUpgradeError(err)),
329 Either::Right(err) => self
330 .proto2
331 .on_connection_event(ConnectionEvent::DialUpgradeError(err)),
332 }
333 }
334 ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
335 self.on_listen_upgrade_error(listen_upgrade_error)
336 }
337 ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
338 self.proto1
339 .on_connection_event(ConnectionEvent::LocalProtocolsChange(
340 supported_protocols.clone(),
341 ));
342 self.proto2
343 .on_connection_event(ConnectionEvent::LocalProtocolsChange(
344 supported_protocols,
345 ));
346 }
347 ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
348 self.proto1
349 .on_connection_event(ConnectionEvent::RemoteProtocolsChange(
350 supported_protocols.clone(),
351 ));
352 self.proto2
353 .on_connection_event(ConnectionEvent::RemoteProtocolsChange(
354 supported_protocols,
355 ));
356 }
357 }
358 }
359}