1use std::{
3 error, fmt,
4 fmt::Debug,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use futures_lite::Stream;
10use futures_sink::Sink;
11use pin_project::pin_project;
12
13use super::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes};
14
/// A connector that combines two underlying connectors, either of which may be absent.
///
/// When opening a channel, `a` takes precedence: `b` is only used when `a` is `None`.
#[derive(Clone, Debug)]
pub struct CombinedConnector<A, B> {
    /// First connector; preferred whenever it is present.
    pub a: Option<A>,
    /// Second connector; used only when `a` is `None`.
    pub b: Option<B>,
}
23
24impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> CombinedConnector<A, B> {
25 pub fn new(a: Option<A>, b: Option<B>) -> Self {
29 Self { a, b }
30 }
31}
32
33#[derive(Debug, Clone)]
35pub struct CombinedListener<A, B> {
36 pub a: Option<A>,
38 pub b: Option<B>,
40 local_addr: Vec<LocalAddr>,
42}
43
44impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> CombinedListener<A, B> {
45 pub fn new(a: Option<A>, b: Option<B>) -> Self {
52 let mut local_addr = Vec::with_capacity(2);
53 if let Some(a) = &a {
54 local_addr.extend(a.local_addr().iter().cloned())
55 };
56 if let Some(b) = &b {
57 local_addr.extend(b.local_addr().iter().cloned())
58 };
59 Self { a, b, local_addr }
60 }
61
62 pub fn into_inner(self) -> (Option<A>, Option<B>) {
64 (self.a, self.b)
65 }
66}
67
68#[pin_project(project = SendSinkProj)]
70pub enum SendSink<A: StreamTypes, B: StreamTypes> {
71 A(#[pin] A::SendSink),
73 B(#[pin] B::SendSink),
75}
76
77impl<A: StreamTypes, B: StreamTypes<In = A::In, Out = A::Out>> Sink<A::Out> for SendSink<A, B> {
78 type Error = self::SendError<A, B>;
79
80 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
81 match self.project() {
82 SendSinkProj::A(sink) => sink.poll_ready(cx).map_err(Self::Error::A),
83 SendSinkProj::B(sink) => sink.poll_ready(cx).map_err(Self::Error::B),
84 }
85 }
86
87 fn start_send(self: Pin<&mut Self>, item: A::Out) -> Result<(), Self::Error> {
88 match self.project() {
89 SendSinkProj::A(sink) => sink.start_send(item).map_err(Self::Error::A),
90 SendSinkProj::B(sink) => sink.start_send(item).map_err(Self::Error::B),
91 }
92 }
93
94 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
95 match self.project() {
96 SendSinkProj::A(sink) => sink.poll_flush(cx).map_err(Self::Error::A),
97 SendSinkProj::B(sink) => sink.poll_flush(cx).map_err(Self::Error::B),
98 }
99 }
100
101 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
102 match self.project() {
103 SendSinkProj::A(sink) => sink.poll_close(cx).map_err(Self::Error::A),
104 SendSinkProj::B(sink) => sink.poll_close(cx).map_err(Self::Error::B),
105 }
106 }
107}
108
109#[pin_project(project = ResStreamProj)]
111pub enum RecvStream<A: StreamTypes, B: StreamTypes> {
112 A(#[pin] A::RecvStream),
114 B(#[pin] B::RecvStream),
116}
117
118impl<A: StreamTypes, B: StreamTypes<In = A::In, Out = A::Out>> Stream for RecvStream<A, B> {
119 type Item = Result<A::In, RecvError<A, B>>;
120
121 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122 match self.project() {
123 ResStreamProj::A(stream) => stream.poll_next(cx).map_err(RecvError::<A, B>::A),
124 ResStreamProj::B(stream) => stream.poll_next(cx).map_err(RecvError::<A, B>::B),
125 }
126 }
127}
128
129#[derive(Debug)]
131pub enum SendError<A: ConnectionErrors, B: ConnectionErrors> {
132 A(A::SendError),
134 B(B::SendError),
136}
137
138impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for SendError<A, B> {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 fmt::Debug::fmt(self, f)
141 }
142}
143
144impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for SendError<A, B> {}
145
146#[derive(Debug)]
148pub enum RecvError<A: ConnectionErrors, B: ConnectionErrors> {
149 A(A::RecvError),
151 B(B::RecvError),
153}
154
155impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for RecvError<A, B> {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 fmt::Debug::fmt(self, f)
158 }
159}
160
161impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for RecvError<A, B> {}
162
163#[derive(Debug)]
165pub enum OpenError<A: ConnectionErrors, B: ConnectionErrors> {
166 A(A::OpenError),
168 B(B::OpenError),
170 NoChannel,
172}
173
174impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for OpenError<A, B> {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 fmt::Debug::fmt(self, f)
177 }
178}
179
180impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for OpenError<A, B> {}
181
182#[derive(Debug)]
184pub enum AcceptError<A: ConnectionErrors, B: ConnectionErrors> {
185 A(A::AcceptError),
187 B(B::AcceptError),
189}
190
191impl<A: ConnectionErrors, B: ConnectionErrors> fmt::Display for AcceptError<A, B> {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 fmt::Debug::fmt(self, f)
194 }
195}
196
197impl<A: ConnectionErrors, B: ConnectionErrors> error::Error for AcceptError<A, B> {}
198
199impl<A: ConnectionErrors, B: ConnectionErrors> ConnectionErrors for CombinedConnector<A, B> {
200 type SendError = self::SendError<A, B>;
201 type RecvError = self::RecvError<A, B>;
202 type OpenError = self::OpenError<A, B>;
203 type AcceptError = self::AcceptError<A, B>;
204}
205
206impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> StreamTypes for CombinedConnector<A, B> {
207 type In = A::In;
208 type Out = A::Out;
209 type RecvStream = self::RecvStream<A, B>;
210 type SendSink = self::SendSink<A, B>;
211}
212
213impl<A: Connector, B: Connector<In = A::In, Out = A::Out>> Connector for CombinedConnector<A, B> {
214 async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
215 let this = self.clone();
216 if let Some(a) = this.a {
218 let (send, recv) = a.open().await.map_err(OpenError::A)?;
219 Ok((SendSink::A(send), RecvStream::A(recv)))
220 } else if let Some(b) = this.b {
221 let (send, recv) = b.open().await.map_err(OpenError::B)?;
222 Ok((SendSink::B(send), RecvStream::B(recv)))
223 } else {
224 Err(OpenError::NoChannel)
225 }
226 }
227}
228
229impl<A: ConnectionErrors, B: ConnectionErrors> ConnectionErrors for CombinedListener<A, B> {
230 type SendError = self::SendError<A, B>;
231 type RecvError = self::RecvError<A, B>;
232 type OpenError = self::OpenError<A, B>;
233 type AcceptError = self::AcceptError<A, B>;
234}
235
236impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> StreamTypes for CombinedListener<A, B> {
237 type In = A::In;
238 type Out = A::Out;
239 type RecvStream = self::RecvStream<A, B>;
240 type SendSink = self::SendSink<A, B>;
241}
242
243impl<A: Listener, B: Listener<In = A::In, Out = A::Out>> Listener for CombinedListener<A, B> {
244 async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::AcceptError> {
245 let a_fut = async {
246 if let Some(a) = &self.a {
247 let (send, recv) = a.accept().await.map_err(AcceptError::A)?;
248 Ok((SendSink::A(send), RecvStream::A(recv)))
249 } else {
250 std::future::pending().await
251 }
252 };
253 let b_fut = async {
254 if let Some(b) = &self.b {
255 let (send, recv) = b.accept().await.map_err(AcceptError::B)?;
256 Ok((SendSink::B(send), RecvStream::B(recv)))
257 } else {
258 std::future::pending().await
259 }
260 };
261 async move {
262 tokio::select! {
263 res = a_fut => res,
264 res = b_fut => res,
265 }
266 }
267 .await
268 }
269
270 fn local_addr(&self) -> &[LocalAddr] {
271 &self.local_addr
272 }
273}
274
#[cfg(test)]
#[cfg(feature = "flume-transport")]
mod tests {
    use crate::transport::{
        combined::{CombinedConnector, OpenError},
        flume::FlumeConnector,
        Connector,
    };

    /// Unit-typed flume connector used for both halves of the combination.
    type UnitFlume = FlumeConnector<(), ()>;

    /// With neither inner connector configured, `open` must fail with `NoChannel`.
    #[tokio::test]
    async fn open_empty_channel() {
        let channel: CombinedConnector<UnitFlume, UnitFlume> = CombinedConnector::new(None, None);
        let outcome = channel.open().await;
        assert!(matches!(outcome, Err(OpenError::NoChannel)));
    }
}