quic_rpc/transport/
flume.rs

1//! Memory transport implementation using [flume]
2//!
3//! [flume]: https://docs.rs/flume/
4use core::fmt;
5use std::{error, fmt::Display, marker::PhantomData, pin::Pin, result, task::Poll};
6
7use futures_lite::{Future, Stream};
8use futures_sink::Sink;
9
10use super::StreamTypes;
11use crate::{
12    transport::{ConnectionErrors, Connector, Listener, LocalAddr},
13    RpcMessage,
14};
15
/// Error when receiving from a channel
///
/// This type has zero inhabitants, so it is always safe to unwrap a result with this error type.
#[derive(Debug)]
pub enum RecvError {}

impl fmt::Display for RecvError {
    fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // No value of this type can ever exist, so this match has no arms to cover.
        match *self {}
    }
}
27
28/// Sink for memory channels
29pub struct SendSink<T: RpcMessage>(pub(crate) flume::r#async::SendSink<'static, T>);
30
31impl<T: RpcMessage> fmt::Debug for SendSink<T> {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        f.debug_struct("SendSink").finish()
34    }
35}
36
37impl<T: RpcMessage> Sink<T> for SendSink<T> {
38    type Error = self::SendError;
39
40    fn poll_ready(
41        mut self: Pin<&mut Self>,
42        cx: &mut std::task::Context<'_>,
43    ) -> Poll<Result<(), Self::Error>> {
44        Pin::new(&mut self.0)
45            .poll_ready(cx)
46            .map_err(|_| SendError::ReceiverDropped)
47    }
48
49    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
50        Pin::new(&mut self.0)
51            .start_send(item)
52            .map_err(|_| SendError::ReceiverDropped)
53    }
54
55    fn poll_flush(
56        mut self: Pin<&mut Self>,
57        cx: &mut std::task::Context<'_>,
58    ) -> Poll<Result<(), Self::Error>> {
59        Pin::new(&mut self.0)
60            .poll_flush(cx)
61            .map_err(|_| SendError::ReceiverDropped)
62    }
63
64    fn poll_close(
65        mut self: Pin<&mut Self>,
66        cx: &mut std::task::Context<'_>,
67    ) -> Poll<Result<(), Self::Error>> {
68        Pin::new(&mut self.0)
69            .poll_close(cx)
70            .map_err(|_| SendError::ReceiverDropped)
71    }
72}
73
74/// Stream for memory channels
75pub struct RecvStream<T: RpcMessage>(pub(crate) flume::r#async::RecvStream<'static, T>);
76
77impl<T: RpcMessage> fmt::Debug for RecvStream<T> {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.debug_struct("RecvStream").finish()
80    }
81}
82
83impl<T: RpcMessage> Stream for RecvStream<T> {
84    type Item = result::Result<T, self::RecvError>;
85
86    fn poll_next(
87        mut self: Pin<&mut Self>,
88        cx: &mut std::task::Context<'_>,
89    ) -> Poll<Option<Self::Item>> {
90        match Pin::new(&mut self.0).poll_next(cx) {
91            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
92            Poll::Ready(None) => Poll::Ready(None),
93            Poll::Pending => Poll::Pending,
94        }
95    }
96}
97
98impl error::Error for RecvError {}
99
100/// A flume based listener.
101///
102/// Created using [channel].
103pub struct FlumeListener<In: RpcMessage, Out: RpcMessage> {
104    #[allow(clippy::type_complexity)]
105    stream: flume::Receiver<(SendSink<Out>, RecvStream<In>)>,
106}
107
108impl<In: RpcMessage, Out: RpcMessage> Clone for FlumeListener<In, Out> {
109    fn clone(&self) -> Self {
110        Self {
111            stream: self.stream.clone(),
112        }
113    }
114}
115
116impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for FlumeListener<In, Out> {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        f.debug_struct("FlumeListener")
119            .field("stream", &self.stream)
120            .finish()
121    }
122}
123
124impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for FlumeListener<In, Out> {
125    type SendError = self::SendError;
126    type RecvError = self::RecvError;
127    type OpenError = self::OpenError;
128    type AcceptError = self::AcceptError;
129}
130
131type Socket<In, Out> = (self::SendSink<Out>, self::RecvStream<In>);
132
133/// Future returned by [FlumeConnector::open]
134pub struct OpenFuture<In: RpcMessage, Out: RpcMessage> {
135    inner: flume::r#async::SendFut<'static, Socket<Out, In>>,
136    res: Option<Socket<In, Out>>,
137}
138
139impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for OpenFuture<In, Out> {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        f.debug_struct("OpenFuture").finish()
142    }
143}
144
145impl<In: RpcMessage, Out: RpcMessage> OpenFuture<In, Out> {
146    fn new(inner: flume::r#async::SendFut<'static, Socket<Out, In>>, res: Socket<In, Out>) -> Self {
147        Self {
148            inner,
149            res: Some(res),
150        }
151    }
152}
153
154impl<In: RpcMessage, Out: RpcMessage> Future for OpenFuture<In, Out> {
155    type Output = result::Result<Socket<In, Out>, self::OpenError>;
156
157    fn poll(
158        mut self: Pin<&mut Self>,
159        cx: &mut std::task::Context<'_>,
160    ) -> std::task::Poll<Self::Output> {
161        match Pin::new(&mut self.inner).poll(cx) {
162            Poll::Ready(Ok(())) => self
163                .res
164                .take()
165                .map(|x| Poll::Ready(Ok(x)))
166                .unwrap_or(Poll::Pending),
167            Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenError::RemoteDropped)),
168            Poll::Pending => Poll::Pending,
169        }
170    }
171}
172
173/// Future returned by [FlumeListener::accept]
174pub struct AcceptFuture<In: RpcMessage, Out: RpcMessage> {
175    wrapped: flume::r#async::RecvFut<'static, (SendSink<Out>, RecvStream<In>)>,
176    _p: PhantomData<(In, Out)>,
177}
178
179impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for AcceptFuture<In, Out> {
180    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181        f.debug_struct("AcceptFuture").finish()
182    }
183}
184
185impl<In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<In, Out> {
186    type Output = result::Result<(SendSink<Out>, RecvStream<In>), AcceptError>;
187
188    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
189        match Pin::new(&mut self.wrapped).poll(cx) {
190            Poll::Ready(Ok((send, recv))) => Poll::Ready(Ok((send, recv))),
191            Poll::Ready(Err(_)) => Poll::Ready(Err(AcceptError::RemoteDropped)),
192            Poll::Pending => Poll::Pending,
193        }
194    }
195}
196
197impl<In: RpcMessage, Out: RpcMessage> StreamTypes for FlumeListener<In, Out> {
198    type In = In;
199    type Out = Out;
200    type SendSink = SendSink<Out>;
201    type RecvStream = RecvStream<In>;
202}
203
204impl<In: RpcMessage, Out: RpcMessage> Listener for FlumeListener<In, Out> {
205    #[allow(refining_impl_trait)]
206    fn accept(&self) -> AcceptFuture<In, Out> {
207        AcceptFuture {
208            wrapped: self.stream.clone().into_recv_async(),
209            _p: PhantomData,
210        }
211    }
212
213    fn local_addr(&self) -> &[LocalAddr] {
214        &[LocalAddr::Mem]
215    }
216}
217
218impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for FlumeConnector<In, Out> {
219    type SendError = self::SendError;
220    type RecvError = self::RecvError;
221    type OpenError = self::OpenError;
222    type AcceptError = self::AcceptError;
223}
224
225impl<In: RpcMessage, Out: RpcMessage> StreamTypes for FlumeConnector<In, Out> {
226    type In = In;
227    type Out = Out;
228    type SendSink = SendSink<Out>;
229    type RecvStream = RecvStream<In>;
230}
231
232impl<In: RpcMessage, Out: RpcMessage> Connector for FlumeConnector<In, Out> {
233    #[allow(refining_impl_trait)]
234    fn open(&self) -> OpenFuture<In, Out> {
235        let (local_send, remote_recv) = flume::bounded::<Out>(128);
236        let (remote_send, local_recv) = flume::bounded::<In>(128);
237        let remote_chan = (
238            SendSink(remote_send.into_sink()),
239            RecvStream(remote_recv.into_stream()),
240        );
241        let local_chan = (
242            SendSink(local_send.into_sink()),
243            RecvStream(local_recv.into_stream()),
244        );
245        OpenFuture::new(self.sink.clone().into_send_async(remote_chan), local_chan)
246    }
247}
248
249/// A flume based connector.
250///
251/// Created using [channel].
252pub struct FlumeConnector<In: RpcMessage, Out: RpcMessage> {
253    #[allow(clippy::type_complexity)]
254    sink: flume::Sender<(SendSink<In>, RecvStream<Out>)>,
255}
256
257impl<In: RpcMessage, Out: RpcMessage> Clone for FlumeConnector<In, Out> {
258    fn clone(&self) -> Self {
259        Self {
260            sink: self.sink.clone(),
261        }
262    }
263}
264
265impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for FlumeConnector<In, Out> {
266    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267        f.debug_struct("FlumeClientChannel")
268            .field("sink", &self.sink)
269            .finish()
270    }
271}
272
/// AcceptError for mem channels.
///
/// There is not much that can go wrong with mem channels. This is what [`AcceptFuture`]
/// resolves to when receiving the next connection fails.
#[derive(Debug)]
pub enum AcceptError {
    /// The remote side of the channel was dropped
    RemoteDropped,
}

impl fmt::Display for AcceptError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The user-facing message is the variant name, exactly as `Debug` renders it.
        write!(f, "{self:?}")
    }
}

impl std::error::Error for AcceptError {}
289
/// SendError for mem channels.
///
/// There is not much that can go wrong with mem channels. Every failure reported by the
/// wrapped flume sink surfaces as this error.
#[derive(Debug)]
pub enum SendError {
    /// Receiver was dropped
    ReceiverDropped,
}

impl fmt::Display for SendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The user-facing message is the variant name, exactly as `Debug` renders it.
        write!(f, "{self:?}")
    }
}

impl error::Error for SendError {}
306
/// OpenError for mem channels.
///
/// This is what [`OpenFuture`] resolves to when sending the new connection to the listener
/// fails.
#[derive(Debug)]
pub enum OpenError {
    /// The remote side of the channel was dropped
    RemoteDropped,
}

impl fmt::Display for OpenError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The user-facing message is the variant name, exactly as `Debug` renders it.
        write!(f, "{self:?}")
    }
}

impl error::Error for OpenError {}
321
/// CreateChannelError for mem channels.
///
/// Creating a mem channel cannot fail, so this enum has no variants and no value of it can
/// exist. A type is still required to fill the error slot.
#[derive(Debug, Clone, Copy)]
pub enum CreateChannelError {}

impl fmt::Display for CreateChannelError {
    fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Uninhabited: there is no variant to format.
        match *self {}
    }
}

impl error::Error for CreateChannelError {}
336
337/// Create a flume listener and a connected flume connector.
338///
339/// `buffer` the size of the buffer for each channel. Keep this at a low value to get backpressure
340pub fn channel<Req: RpcMessage, Res: RpcMessage>(
341    buffer: usize,
342) -> (FlumeListener<Req, Res>, FlumeConnector<Res, Req>) {
343    let (sink, stream) = flume::bounded(buffer);
344    (FlumeListener { stream }, FlumeConnector { sink })
345}