quic_rpc/transport/combined.rs

1//! Transport that combines two other transports
2use std::{
3    error, fmt,
4    fmt::Debug,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use futures_lite::Stream;
10use futures_sink::Sink;
11use pin_project::pin_project;
12
13use super::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes};
14
15/// A connection that combines two other connections
16#[derive(Debug, Clone)]
17pub struct CombinedConnector<A, B> {
18    /// First connection
19    pub a: Option<A>,
20    /// Second connection
21    pub b: Option<B>,
22}
23
24impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> CombinedConnector<A, B> {
25    /// Create a combined connection from two other connections
26    ///
27    /// It will always use the first connection that is not `None`.
28    pub fn new(a: Option<A>, b: Option<B>) -> Self {
29        Self { a, b }
30    }
31}
32
33/// An endpoint that combines two other endpoints
34#[derive(Debug, Clone)]
35pub struct CombinedListener<A, B> {
36    /// First endpoint
37    pub a: Option<A>,
38    /// Second endpoint
39    pub b: Option<B>,
40    /// Local addresses from all endpoints
41    local_addr: Vec<LocalAddr>,
42}
43
44impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> CombinedListener<A, B> {
45    /// Create a combined listener from two other listeners
46    ///
47    /// When listening for incoming connections with
48    /// [`Listener::accept`], all configured channels will be listened on,
49    /// and the first to receive a connection will be used. If no channels are configured,
50    /// accept will not throw an error but just wait forever.
51    pub fn new(a: Option<A>, b: Option<B>) -> Self {
52        let mut local_addr = Vec::with_capacity(2);
53        if let Some(a) = &a {
54            local_addr.extend(a.local_addr().iter().cloned())
55        };
56        if let Some(b) = &b {
57            local_addr.extend(b.local_addr().iter().cloned())
58        };
59        Self { a, b, local_addr }
60    }
61
62    /// Get back the inner endpoints
63    pub fn into_inner(self) -> (Option<A>, Option<B>) {
64        (self.a, self.b)
65    }
66}
67
68/// Send sink for combined channels
69#[pin_project(project = SendSinkProj)]
70pub enum SendSink<A: StreamTypes, B: StreamTypes> {
71    /// A variant
72    A(#[pin] A::SendSink),
73    /// B variant
74    B(#[pin] B::SendSink),
75}
76
77impl<A: StreamTypes, B: StreamTypes<In = A::In, Out = A::Out>> Sink<A::Out> for SendSink<A, B> {
78    type Error = self::SendError<A, B>;
79
80    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
81        match self.project() {
82            SendSinkProj::A(sink) => sink.poll_ready(cx).map_err(Self::Error::A),
83            SendSinkProj::B(sink) => sink.poll_ready(cx).map_err(Self::Error::B),
84        }
85    }
86
87    fn start_send(self: Pin<&mut Self>, item: A::Out) -> Result<(), Self::Error> {
88        match self.project() {
89            SendSinkProj::A(sink) => sink.start_send(item).map_err(Self::Error::A),
90            SendSinkProj::B(sink) => sink.start_send(item).map_err(Self::Error::B),
91        }
92    }
93
94    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95        match self.project() {
96            SendSinkProj::A(sink) => sink.poll_flush(cx).map_err(Self::Error::A),
97            SendSinkProj::B(sink) => sink.poll_flush(cx).map_err(Self::Error::B),
98        }
99    }
100
101    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102        match self.project() {
103            SendSinkProj::A(sink) => sink.poll_close(cx).map_err(Self::Error::A),
104            SendSinkProj::B(sink) => sink.poll_close(cx).map_err(Self::Error::B),
105        }
106    }
107}
108
109/// RecvStream for combined channels
110#[pin_project(project = ResStreamProj)]
111pub enum RecvStream<A: StreamTypes, B: StreamTypes> {
112    /// A variant
113    A(#[pin] A::RecvStream),
114    /// B variant
115    B(#[pin] B::RecvStream),
116}
117
118impl<A: StreamTypes, B: StreamTypes<In = A::In, Out = A::Out>> Stream for RecvStream<A, B> {
119    type Item = Result<A::In, RecvError<A, B>>;
120
121    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122        match self.project() {
123            ResStreamProj::A(stream) => stream.poll_next(cx).map_err(RecvError::<A, B>::A),
124            ResStreamProj::B(stream) => stream.poll_next(cx).map_err(RecvError::<A, B>::B),
125        }
126    }
127}
128
129/// SendError for combined channels
130#[derive(Debug)]
131pub enum SendError<A: ConnectionErrors, B: ConnectionErrors> {
132    /// A variant
133    A(A::SendError),
134    /// B variant
135    B(B::SendError),
136}
137
138impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for SendError<A, B> {
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        fmt::Debug::fmt(self, f)
141    }
142}
143
144impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for SendError<A, B> {}
145
146/// RecvError for combined channels
147#[derive(Debug)]
148pub enum RecvError<A: ConnectionErrors, B: ConnectionErrors> {
149    /// A variant
150    A(A::RecvError),
151    /// B variant
152    B(B::RecvError),
153}
154
155impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for RecvError<A, B> {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        fmt::Debug::fmt(self, f)
158    }
159}
160
161impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for RecvError<A, B> {}
162
163/// OpenError for combined channels
164#[derive(Debug)]
165pub enum OpenError<A: ConnectionErrors, B: ConnectionErrors> {
166    /// A variant
167    A(A::OpenError),
168    /// B variant
169    B(B::OpenError),
170    /// None of the two channels is configured
171    NoChannel,
172}
173
174impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for OpenError<A, B> {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        fmt::Debug::fmt(self, f)
177    }
178}
179
180impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for OpenError<A, B> {}
181
182/// AcceptError for combined channels
183#[derive(Debug)]
184pub enum AcceptError<A: ConnectionErrors, B: ConnectionErrors> {
185    /// A variant
186    A(A::AcceptError),
187    /// B variant
188    B(B::AcceptError),
189}
190
191impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for AcceptError<A, B> {
192    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193        fmt::Debug::fmt(self, f)
194    }
195}
196
197impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for AcceptError<A, B> {}
198
199impl<A: ConnectionErrors, B: ConnectionErrors> ConnectionErrors for CombinedConnector<A, B> {
200    type SendError = self::SendError<A, B>;
201    type RecvError = self::RecvError<A, B>;
202    type OpenError = self::OpenError<A, B>;
203    type AcceptError = self::AcceptError<A, B>;
204}
205
206impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> StreamTypes for CombinedConnector<A, B> {
207    type In = A::In;
208    type Out = A::Out;
209    type RecvStream = self::RecvStream<A, B>;
210    type SendSink = self::SendSink<A, B>;
211}
212
213impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> Connector for CombinedConnector<A, B> {
214    async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
215        let this = self.clone();
216        // try a first, then b
217        if let Some(a) = this.a {
218            let (send, recv) = a.open().await.map_err(OpenError::A)?;
219            Ok((SendSink::A(send), RecvStream::A(recv)))
220        } else if let Some(b) = this.b {
221            let (send, recv) = b.open().await.map_err(OpenError::B)?;
222            Ok((SendSink::B(send), RecvStream::B(recv)))
223        } else {
224            Err(OpenError::NoChannel)
225        }
226    }
227}
228
229impl<A: ConnectionErrors, B: ConnectionErrors> ConnectionErrors for CombinedListener<A, B> {
230    type SendError = self::SendError<A, B>;
231    type RecvError = self::RecvError<A, B>;
232    type OpenError = self::OpenError<A, B>;
233    type AcceptError = self::AcceptError<A, B>;
234}
235
236impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> StreamTypes for CombinedListener<A, B> {
237    type In = A::In;
238    type Out = A::Out;
239    type RecvStream = self::RecvStream<A, B>;
240    type SendSink = self::SendSink<A, B>;
241}
242
243impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> Listener for CombinedListener<A, B> {
244    async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> {
245        let a_fut = async {
246            if let Some(a) = &self.a {
247                let (send, recv) = a.accept().await.map_err(AcceptError::A)?;
248                Ok((SendSink::A(send), RecvStream::A(recv)))
249            } else {
250                std::future::pending().await
251            }
252        };
253        let b_fut = async {
254            if let Some(b) = &self.b {
255                let (send, recv) = b.accept().await.map_err(AcceptError::B)?;
256                Ok((SendSink::B(send), RecvStream::B(recv)))
257            } else {
258                std::future::pending().await
259            }
260        };
261        async move {
262            tokio::select! {
263                res = a_fut => res,
264                res = b_fut => res,
265            }
266        }
267        .await
268    }
269
270    fn local_addr(&self) -> &[LocalAddr] {
271        &self.local_addr
272    }
273}
274
#[cfg(test)]
#[cfg(feature = "flume-transport")]
mod tests {
    use crate::transport::{
        combined::{CombinedConnector, OpenError},
        flume::FlumeConnector,
        Connector,
    };

    /// Opening a combined connector with neither side configured must fail
    /// with `OpenError::NoChannel`.
    #[tokio::test]
    async fn open_empty_channel() {
        type UnitFlume = FlumeConnector<(), ()>;
        let connector = CombinedConnector::<UnitFlume, UnitFlume>::new(None, None);
        let outcome = connector.open().await;
        assert!(matches!(outcome, Err(OpenError::NoChannel)));
    }
}
292}