quic_rpc/transport/
boxed.rs

1//! Boxed transport with concrete types
2
3use std::{
4    fmt::Debug,
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use futures_lite::FutureExt;
11use futures_sink::Sink;
12use futures_util::{future::BoxFuture, SinkExt, Stream, StreamExt, TryStreamExt};
13use pin_project::pin_project;
14
15use super::{ConnectionErrors, StreamTypes};
16use crate::RpcMessage;
/// A heap-allocated, pinned future that is both `Send` and `Sync`.
///
/// Used for the boxed variant of the accept future.
type BoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + Sync + 'a>>;
18
/// The two possible representations of a [`SendSink`].
enum SendSinkInner<T: RpcMessage> {
    /// A flume send sink, stored as-is (only with the `flume-transport` feature).
    #[cfg(feature = "flume-transport")]
    Direct(::flume::r#async::SendSink<'static, T>),
    /// Any other sink, type-erased and pinned on the heap; its error type is
    /// already `anyhow::Error`.
    Boxed(Pin<Box<dyn Sink<T, Error = anyhow::Error> + Send + Sync + 'static>>),
}
24
/// A sink that can be used to send messages to the remote end of a channel.
///
/// For local channels, this is a thin wrapper around a flume send sink.
/// For network channels, this contains a boxed sink, since it is reasonable
/// to assume that in that case the additional overhead of boxing is negligible.
///
/// All errors are reported as [`anyhow::Error`].
#[pin_project]
pub struct SendSink<T: RpcMessage>(SendSinkInner<T>);
32
33impl<T: RpcMessage> SendSink<T> {
34    /// Create a new send sink from a boxed sink
35    pub fn boxed(sink: impl Sink<T, Error = anyhow::Error> + Send + Sync + 'static) -> Self {
36        Self(SendSinkInner::Boxed(Box::pin(sink)))
37    }
38
39    /// Create a new send sink from a direct flume send sink
40    #[cfg(feature = "flume-transport")]
41    pub(crate) fn direct(sink: ::flume::r#async::SendSink<'static, T>) -> Self {
42        Self(SendSinkInner::Direct(sink))
43    }
44}
45
46impl<T: RpcMessage> Sink<T> for SendSink<T> {
47    type Error = anyhow::Error;
48
49    fn poll_ready(
50        self: std::pin::Pin<&mut Self>,
51        cx: &mut std::task::Context<'_>,
52    ) -> Poll<Result<(), Self::Error>> {
53        match self.project().0 {
54            #[cfg(feature = "flume-transport")]
55            SendSinkInner::Direct(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from),
56            SendSinkInner::Boxed(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from),
57        }
58    }
59
60    fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
61        match self.project().0 {
62            #[cfg(feature = "flume-transport")]
63            SendSinkInner::Direct(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from),
64            SendSinkInner::Boxed(sink) => sink.start_send_unpin(item),
65        }
66    }
67
68    fn poll_flush(
69        self: std::pin::Pin<&mut Self>,
70        cx: &mut Context<'_>,
71    ) -> Poll<Result<(), Self::Error>> {
72        match self.project().0 {
73            #[cfg(feature = "flume-transport")]
74            SendSinkInner::Direct(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from),
75            SendSinkInner::Boxed(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from),
76        }
77    }
78
79    fn poll_close(
80        self: std::pin::Pin<&mut Self>,
81        cx: &mut Context<'_>,
82    ) -> Poll<Result<(), Self::Error>> {
83        match self.project().0 {
84            #[cfg(feature = "flume-transport")]
85            SendSinkInner::Direct(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from),
86            SendSinkInner::Boxed(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from),
87        }
88    }
89}
90
/// The two possible representations of a [`RecvStream`].
enum RecvStreamInner<T: RpcMessage> {
    /// A flume receive stream, stored as-is (only with the `flume-transport` feature).
    #[cfg(feature = "flume-transport")]
    Direct(::flume::r#async::RecvStream<'static, T>),
    /// Any other stream, type-erased and pinned on the heap; its items are
    /// already `Result<T, anyhow::Error>`.
    Boxed(Pin<Box<dyn Stream<Item = Result<T, anyhow::Error>> + Send + Sync + 'static>>),
}
96
/// A stream that can be used to receive messages from the remote end of a channel.
///
/// For local channels, this is a thin wrapper around a flume receive stream.
/// For network channels, this contains a boxed stream, since it is reasonable
/// to assume that in that case the additional overhead of boxing is negligible.
#[pin_project]
pub struct RecvStream<T: RpcMessage>(RecvStreamInner<T>);
103
104impl<T: RpcMessage> RecvStream<T> {
105    /// Create a new receive stream from a boxed stream
106    pub fn boxed(
107        stream: impl Stream<Item = Result<T, anyhow::Error>> + Send + Sync + 'static,
108    ) -> Self {
109        Self(RecvStreamInner::Boxed(Box::pin(stream)))
110    }
111
112    /// Create a new receive stream from a direct flume receive stream
113    #[cfg(feature = "flume-transport")]
114    pub(crate) fn direct(stream: ::flume::r#async::RecvStream<'static, T>) -> Self {
115        Self(RecvStreamInner::Direct(stream))
116    }
117}
118
119impl<T: RpcMessage> Stream for RecvStream<T> {
120    type Item = Result<T, anyhow::Error>;
121
122    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
123        match self.project().0 {
124            #[cfg(feature = "flume-transport")]
125            RecvStreamInner::Direct(stream) => match stream.poll_next_unpin(cx) {
126                Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))),
127                Poll::Ready(None) => Poll::Ready(None),
128                Poll::Pending => Poll::Pending,
129            },
130            RecvStreamInner::Boxed(stream) => stream.poll_next_unpin(cx),
131        }
132    }
133}
134
/// The two possible representations of an [`OpenFuture`].
enum OpenFutureInner<'a, In: RpcMessage, Out: RpcMessage> {
    /// A direct future (todo)
    ///
    /// Holds the flume transport's own open future; only present with the
    /// `flume-transport` feature.
    #[cfg(feature = "flume-transport")]
    Direct(super::flume::OpenFuture<In, Out>),
    /// A boxed future
    ///
    /// Resolves directly to the concrete sink/stream pair.
    Boxed(BoxFuture<'a, anyhow::Result<(SendSink<Out>, RecvStream<In>)>>),
}
142
/// A concrete future for opening a channel
///
/// Resolves to a `(SendSink<Out>, RecvStream<In>)` pair or an `anyhow::Error`.
#[pin_project]
pub struct OpenFuture<'a, In: RpcMessage, Out: RpcMessage>(OpenFutureInner<'a, In, Out>);
146
147impl<'a, In: RpcMessage, Out: RpcMessage> OpenFuture<'a, In, Out> {
148    #[cfg(feature = "flume-transport")]
149    fn direct(f: super::flume::OpenFuture<In, Out>) -> Self {
150        Self(OpenFutureInner::Direct(f))
151    }
152
153    /// Create a new boxed future
154    pub fn boxed(
155        f: impl Future<Output = anyhow::Result<(SendSink<Out>, RecvStream<In>)>> + Send + 'a,
156    ) -> Self {
157        Self(OpenFutureInner::Boxed(Box::pin(f)))
158    }
159}
160
161impl<In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'_, In, Out> {
162    type Output = anyhow::Result<(SendSink<Out>, RecvStream<In>)>;
163
164    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
165        match self.project().0 {
166            #[cfg(feature = "flume-transport")]
167            OpenFutureInner::Direct(f) => f
168                .poll(cx)
169                .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0)))
170                .map_err(|e| e.into()),
171            OpenFutureInner::Boxed(f) => f.poll(cx),
172        }
173    }
174}
175
/// The two possible representations of an [`AcceptFuture`].
enum AcceptFutureInner<'a, In: RpcMessage, Out: RpcMessage> {
    /// A direct future
    ///
    /// Holds the flume transport's own accept future; only present with the
    /// `flume-transport` feature.
    #[cfg(feature = "flume-transport")]
    Direct(super::flume::AcceptFuture<In, Out>),
    /// A boxed future
    ///
    /// Resolves directly to the concrete sink/stream pair.
    Boxed(BoxedFuture<'a, anyhow::Result<(SendSink<Out>, RecvStream<In>)>>),
}
183
/// Concrete accept future
///
/// Resolves to a `(SendSink<Out>, RecvStream<In>)` pair or an `anyhow::Error`.
#[pin_project]
pub struct AcceptFuture<'a, In: RpcMessage, Out: RpcMessage>(AcceptFutureInner<'a, In, Out>);
187
188impl<'a, In: RpcMessage, Out: RpcMessage> AcceptFuture<'a, In, Out> {
189    #[cfg(feature = "flume-transport")]
190    fn direct(f: super::flume::AcceptFuture<In, Out>) -> Self {
191        Self(AcceptFutureInner::Direct(f))
192    }
193
194    /// Create a new boxed future
195    pub fn boxed(
196        f: impl Future<Output = anyhow::Result<(SendSink<Out>, RecvStream<In>)>> + Send + Sync + 'a,
197    ) -> Self {
198        Self(AcceptFutureInner::Boxed(Box::pin(f)))
199    }
200}
201
202impl<In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<'_, In, Out> {
203    type Output = anyhow::Result<(SendSink<Out>, RecvStream<In>)>;
204
205    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
206        match self.project().0 {
207            #[cfg(feature = "flume-transport")]
208            AcceptFutureInner::Direct(f) => f
209                .poll(cx)
210                .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0)))
211                .map_err(|e| e.into()),
212            AcceptFutureInner::Boxed(f) => f.poll(cx),
213        }
214    }
215}
216
217/// A boxable connector
218pub trait BoxableConnector<In: RpcMessage, Out: RpcMessage>: Debug + Send + Sync + 'static {
219    /// Clone the connection and box it
220    fn clone_box(&self) -> Box<dyn BoxableConnector<In, Out>>;
221
222    /// Open a channel to the remote che
223    fn open_boxed(&self) -> OpenFuture<In, Out>;
224}
225
/// A boxed connector
///
/// Owns a type-erased [`BoxableConnector`].
#[derive(Debug)]
pub struct BoxedConnector<In, Out>(Box<dyn BoxableConnector<In, Out>>);
229
230impl<In: RpcMessage, Out: RpcMessage> BoxedConnector<In, Out> {
231    /// Wrap a boxable connector into a box, transforming all the types to concrete types
232    pub fn new(x: impl BoxableConnector<In, Out>) -> Self {
233        Self(Box::new(x))
234    }
235}
236
237impl<In: RpcMessage, Out: RpcMessage> Clone for BoxedConnector<In, Out> {
238    fn clone(&self) -> Self {
239        Self(self.0.clone_box())
240    }
241}
242
243impl<In: RpcMessage, Out: RpcMessage> StreamTypes for BoxedConnector<In, Out> {
244    type In = In;
245    type Out = Out;
246    type RecvStream = RecvStream<In>;
247    type SendSink = SendSink<Out>;
248}
249
250impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for BoxedConnector<In, Out> {
251    type SendError = anyhow::Error;
252    type RecvError = anyhow::Error;
253    type OpenError = anyhow::Error;
254    type AcceptError = anyhow::Error;
255}
256
257impl<In: RpcMessage, Out: RpcMessage> super::Connector for BoxedConnector<In, Out> {
258    async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
259        self.0.open_boxed().await
260    }
261}
262
/// Stream types for boxed streams
///
/// A marker type that holds no data; it only carries the `In`/`Out` type
/// parameters.
#[derive(Debug)]
pub struct BoxedStreamTypes<In, Out> {
    // Zero-sized marker tying the struct to its type parameters.
    _p: std::marker::PhantomData<(In, Out)>,
}
268
269impl<In, Out> Clone for BoxedStreamTypes<In, Out> {
270    fn clone(&self) -> Self {
271        Self {
272            _p: std::marker::PhantomData,
273        }
274    }
275}
276
277impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for BoxedStreamTypes<In, Out> {
278    type SendError = anyhow::Error;
279    type RecvError = anyhow::Error;
280    type OpenError = anyhow::Error;
281    type AcceptError = anyhow::Error;
282}
283
284impl<In: RpcMessage, Out: RpcMessage> StreamTypes for BoxedStreamTypes<In, Out> {
285    type In = In;
286    type Out = Out;
287    type RecvStream = RecvStream<In>;
288    type SendSink = SendSink<Out>;
289}
290
291/// A boxable listener
292pub trait BoxableListener<In: RpcMessage, Out: RpcMessage>: Debug + Send + Sync + 'static {
293    /// Clone the listener and box it
294    fn clone_box(&self) -> Box<dyn BoxableListener<In, Out>>;
295
296    /// Accept a channel from a remote client
297    fn accept_bi_boxed(&self) -> AcceptFuture<In, Out>;
298
299    /// Get the local address
300    fn local_addr(&self) -> &[super::LocalAddr];
301}
302
/// A boxed listener
///
/// Owns a type-erased [`BoxableListener`].
#[derive(Debug)]
pub struct BoxedListener<In: RpcMessage, Out: RpcMessage>(Box<dyn BoxableListener<In, Out>>);
306
307impl<In: RpcMessage, Out: RpcMessage> BoxedListener<In, Out> {
308    /// Wrap a boxable listener into a box, transforming all the types to concrete types
309    pub fn new(x: impl BoxableListener<In, Out>) -> Self {
310        Self(Box::new(x))
311    }
312}
313
314impl<In: RpcMessage, Out: RpcMessage> Clone for BoxedListener<In, Out> {
315    fn clone(&self) -> Self {
316        Self(self.0.clone_box())
317    }
318}
319
320impl<In: RpcMessage, Out: RpcMessage> StreamTypes for BoxedListener<In, Out> {
321    type In = In;
322    type Out = Out;
323    type RecvStream = RecvStream<In>;
324    type SendSink = SendSink<Out>;
325}
326
327impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for BoxedListener<In, Out> {
328    type SendError = anyhow::Error;
329    type RecvError = anyhow::Error;
330    type OpenError = anyhow::Error;
331    type AcceptError = anyhow::Error;
332}
333
334impl<In: RpcMessage, Out: RpcMessage> super::Listener for BoxedListener<In, Out> {
335    fn accept(
336        &self,
337    ) -> impl Future<Output = Result<(Self::SendSink, Self::RecvStream), Self::AcceptError>> + Send
338    {
339        self.0.accept_bi_boxed()
340    }
341
342    fn local_addr(&self) -> &[super::LocalAddr] {
343        self.0.local_addr()
344    }
345}
346impl<In: RpcMessage, Out: RpcMessage> BoxableConnector<In, Out> for BoxedConnector<In, Out> {
347    fn clone_box(&self) -> Box<dyn BoxableConnector<In, Out>> {
348        Box::new(self.clone())
349    }
350
351    fn open_boxed(&self) -> OpenFuture<In, Out> {
352        OpenFuture::boxed(crate::transport::Connector::open(self))
353    }
354}
355
/// Boxing support for the quinn connector.
#[cfg(feature = "quinn-transport")]
impl<In: RpcMessage, Out: RpcMessage> BoxableConnector<In, Out>
    for super::quinn::QuinnConnector<In, Out>
{
    fn clone_box(&self) -> Box<dyn BoxableConnector<In, Out>> {
        Box::new(self.clone())
    }

    fn open_boxed(&self) -> OpenFuture<In, Out> {
        // `OpenFuture::boxed` pins the future on the heap itself, so the async
        // block is handed over unboxed to avoid a second allocation.
        OpenFuture::boxed(async move {
            let (send, recv) = super::Connector::open(self).await?;
            // map the error types to anyhow
            let send = send.sink_map_err(anyhow::Error::from);
            let recv = recv.map_err(anyhow::Error::from);
            // return the boxed streams
            anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv)))
        })
    }
}
376
/// Boxing support for the quinn listener.
#[cfg(feature = "quinn-transport")]
impl<In: RpcMessage, Out: RpcMessage> BoxableListener<In, Out>
    for super::quinn::QuinnListener<In, Out>
{
    fn clone_box(&self) -> Box<dyn BoxableListener<In, Out>> {
        let cloned = self.clone();
        Box::new(cloned)
    }

    fn accept_bi_boxed(&self) -> AcceptFuture<In, Out> {
        AcceptFuture::boxed(async move {
            let (sink, stream) = super::Listener::accept(self).await?;
            // Convert the transport errors to anyhow, then erase the types.
            let sink = SendSink::boxed(sink.sink_map_err(anyhow::Error::from));
            let stream = RecvStream::boxed(stream.map_err(anyhow::Error::from));
            anyhow::Ok((sink, stream))
        })
    }

    fn local_addr(&self) -> &[super::LocalAddr] {
        super::Listener::local_addr(self)
    }
}
399
/// Boxing support for the iroh connector.
#[cfg(feature = "iroh-transport")]
impl<In: RpcMessage, Out: RpcMessage> BoxableConnector<In, Out>
    for super::iroh::IrohConnector<In, Out>
{
    fn clone_box(&self) -> Box<dyn BoxableConnector<In, Out>> {
        Box::new(self.clone())
    }

    fn open_boxed(&self) -> OpenFuture<In, Out> {
        // `OpenFuture::boxed` pins the future on the heap itself, so the async
        // block is handed over unboxed to avoid a second allocation.
        OpenFuture::boxed(async move {
            let (send, recv) = super::Connector::open(self).await?;
            // map the error types to anyhow
            let send = send.sink_map_err(anyhow::Error::from);
            let recv = recv.map_err(anyhow::Error::from);
            // return the boxed streams
            anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv)))
        })
    }
}
420
/// Boxing support for the iroh listener.
#[cfg(feature = "iroh-transport")]
impl<In: RpcMessage, Out: RpcMessage> BoxableListener<In, Out>
    for super::iroh::IrohListener<In, Out>
{
    fn clone_box(&self) -> Box<dyn BoxableListener<In, Out>> {
        let cloned = self.clone();
        Box::new(cloned)
    }

    fn accept_bi_boxed(&self) -> AcceptFuture<In, Out> {
        AcceptFuture::boxed(async move {
            let (sink, stream) = super::Listener::accept(self).await?;
            // Convert the transport errors to anyhow, then erase the types.
            let sink = SendSink::boxed(sink.sink_map_err(anyhow::Error::from));
            let stream = RecvStream::boxed(stream.map_err(anyhow::Error::from));
            anyhow::Ok((sink, stream))
        })
    }

    fn local_addr(&self) -> &[super::LocalAddr] {
        super::Listener::local_addr(self)
    }
}
443
/// Boxing support for the flume connector; its open future is stored directly.
#[cfg(feature = "flume-transport")]
impl<In: RpcMessage, Out: RpcMessage> BoxableConnector<In, Out>
    for super::flume::FlumeConnector<In, Out>
{
    fn clone_box(&self) -> Box<dyn BoxableConnector<In, Out>> {
        let cloned = self.clone();
        Box::new(cloned)
    }

    fn open_boxed(&self) -> OpenFuture<In, Out> {
        let opening = super::Connector::open(self);
        OpenFuture::direct(opening)
    }
}
456
/// Boxing support for the flume listener; its accept future is stored directly.
#[cfg(feature = "flume-transport")]
impl<In: RpcMessage, Out: RpcMessage> BoxableListener<In, Out>
    for super::flume::FlumeListener<In, Out>
{
    fn clone_box(&self) -> Box<dyn BoxableListener<In, Out>> {
        let cloned = self.clone();
        Box::new(cloned)
    }

    fn accept_bi_boxed(&self) -> AcceptFuture<In, Out> {
        let accepting = super::Listener::accept(self);
        AcceptFuture::direct(accepting)
    }

    fn local_addr(&self) -> &[super::LocalAddr] {
        super::Listener::local_addr(self)
    }
}
473
474impl<In, Out, C> BoxableConnector<In, Out> for super::mapped::MappedConnector<In, Out, C>
475where
476    In: RpcMessage,
477    Out: RpcMessage,
478    C: super::Connector,
479    C::Out: From<Out>,
480    In: TryFrom<C::In>,
481    C::SendError: Into<anyhow::Error>,
482    C::RecvError: Into<anyhow::Error>,
483    C::OpenError: Into<anyhow::Error>,
484{
485    fn clone_box(&self) -> Box<dyn BoxableConnector<In, Out>> {
486        Box::new(self.clone())
487    }
488
489    fn open_boxed(&self) -> OpenFuture<In, Out> {
490        let f = Box::pin(async move {
491            let (send, recv) = super::Connector::open(self).await.map_err(|e| e.into())?;
492            // map the error types to anyhow
493            let send = send.sink_map_err(|e| e.into());
494            let recv = recv.map_err(|e| e.into());
495            // return the boxed streams
496            anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv)))
497        });
498        OpenFuture::boxed(f)
499    }
500}
501
#[cfg(test)]
mod tests {
    use crate::Service;

    /// Minimal service definition with `u64` requests and responses.
    #[derive(Debug, Clone)]
    struct FooService;

    impl Service for FooService {
        type Req = u64;
        type Res = u64;
    }

    /// Round-trips one message through a boxed flume listener/connector pair
    /// and checks that the echo server returns it unchanged.
    #[cfg(feature = "flume-transport")]
    #[tokio::test]
    async fn box_smoke() {
        use futures_lite::StreamExt;
        use futures_util::SinkExt;

        use crate::transport::{Connector, Listener};

        let (server, client) = crate::transport::flume::channel(1);
        let server = super::BoxedListener::new(server);
        let client = super::BoxedConnector::new(client);
        // spawn echo server
        tokio::spawn(async move {
            while let Ok((mut send, mut recv)) = server.accept().await {
                if let Some(Ok(msg)) = recv.next().await {
                    send.send(msg).await.ok();
                }
            }
            anyhow::Ok(())
        });
        // Fail loudly instead of silently skipping the checks when a step fails.
        let (mut send, mut recv) = client.open().await.expect("opening a channel failed");
        send.send(1).await.expect("sending the request failed");
        let res = recv.next().await;
        println!("{:?}", res);
        assert!(matches!(res, Some(Ok(1))), "unexpected echo response");
    }
}