quic_rpc/transport/mapped.rs

1//! Transport with mapped input and output types.
2use std::{
3    fmt::{Debug, Display},
4    marker::PhantomData,
5    task::{Context, Poll},
6};
7
8use futures_lite::{Stream, StreamExt};
9use futures_util::SinkExt;
10use pin_project::pin_project;
11
12use super::{ConnectionErrors, Connector, StreamTypes};
13use crate::{RpcError, RpcMessage};
14
/// A connection that maps input and output types.
///
/// Wraps an inner connector `C`. Incoming messages of the inner connector are
/// converted into `In` with `TryFrom`, and outgoing `Out` messages are
/// converted into the inner connector's message type with `From`.
pub struct MappedConnector<In, Out, C> {
    /// The wrapped connector.
    inner: C,
    /// Marker for the mapped message types; no `In`/`Out` values are stored.
    _p: std::marker::PhantomData<(In, Out)>,
}

// Implemented by hand (like `Clone`) so that only `C: Debug` is required.
// `#[derive(Debug)]` would also demand `In: Debug` and `Out: Debug`, although
// no values of those types are ever stored.
impl<In, Out, C: Debug> Debug for MappedConnector<In, Out, C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MappedConnector")
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}
21
22impl<In, Out, C> MappedConnector<In, Out, C>
23where
24    C: Connector,
25    In: TryFrom<C::In>,
26    C::Out: From<Out>,
27{
28    /// Create a new mapped connection
29    pub fn new(inner: C) -> Self {
30        Self {
31            inner,
32            _p: std::marker::PhantomData,
33        }
34    }
35}
36
37impl<In, Out, C> Clone for MappedConnector<In, Out, C>
38where
39    C: Clone,
40{
41    fn clone(&self) -> Self {
42        Self {
43            inner: self.inner.clone(),
44            _p: std::marker::PhantomData,
45        }
46    }
47}
48
49impl<In, Out, C> ConnectionErrors for MappedConnector<In, Out, C>
50where
51    In: RpcMessage,
52    Out: RpcMessage,
53    C: ConnectionErrors,
54{
55    type RecvError = ErrorOrMapError<C::RecvError>;
56    type SendError = C::SendError;
57    type OpenError = C::OpenError;
58    type AcceptError = C::AcceptError;
59}
60
61impl<In, Out, C> StreamTypes for MappedConnector<In, Out, C>
62where
63    C: StreamTypes,
64    In: RpcMessage,
65    Out: RpcMessage,
66    In: TryFrom<C::In>,
67    C::Out: From<Out>,
68{
69    type In = In;
70    type Out = Out;
71    type RecvStream = MappedRecvStream<C::RecvStream, In>;
72    type SendSink = MappedSendSink<C::SendSink, Out, C::Out>;
73}
74
75impl<In, Out, C> Connector for MappedConnector<In, Out, C>
76where
77    C: Connector,
78    In: RpcMessage,
79    Out: RpcMessage,
80    In: TryFrom<C::In>,
81    C::Out: From<Out>,
82{
83    fn open(
84        &self,
85    ) -> impl std::future::Future<Output = Result<(Self::SendSink, Self::RecvStream), Self::OpenError>>
86           + Send {
87        let inner = self.inner.open();
88        async move {
89            let (send, recv) = inner.await?;
90            Ok((MappedSendSink::new(send), MappedRecvStream::new(recv)))
91        }
92    }
93}
94
95/// A combinator that maps a stream of incoming messages to a different type
96#[pin_project]
97pub struct MappedRecvStream<S, In> {
98    inner: S,
99    _p: std::marker::PhantomData<In>,
100}
101
102impl<S, In> MappedRecvStream<S, In> {
103    /// Create a new mapped receive stream
104    pub fn new(inner: S) -> Self {
105        Self {
106            inner,
107            _p: std::marker::PhantomData,
108        }
109    }
110}
111
/// Error yielded by [`MappedRecvStream`].
///
/// Either the inner stream reported an error, or a received message could not
/// be converted from the inner message type into the mapped one.
#[derive(Debug)]
pub enum ErrorOrMapError<E> {
    /// The inner stream returned an error.
    Inner(E),
    /// The received message failed the `TryFrom` conversion.
    Conversion,
}

impl<E> std::error::Error for ErrorOrMapError<E> where E: Debug + Display {}

impl<E> Display for ErrorOrMapError<E>
where
    E: Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Conversion => f.write_str("Conversion error"),
            Self::Inner(inner) => write!(f, "Inner error: {inner}"),
        }
    }
}
131
132impl<S, In0, In, E> Stream for MappedRecvStream<S, In>
133where
134    S: Stream<Item = Result<In0, E>> + Unpin,
135    In: TryFrom<In0>,
136    E: RpcError,
137{
138    type Item = Result<In, ErrorOrMapError<E>>;
139
140    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
141        match self.project().inner.poll_next(cx) {
142            Poll::Ready(Some(Ok(item))) => {
143                let item = item.try_into().map_err(|_| ErrorOrMapError::Conversion);
144                Poll::Ready(Some(item))
145            }
146            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ErrorOrMapError::Inner(e)))),
147            Poll::Ready(None) => Poll::Ready(None),
148            Poll::Pending => Poll::Pending,
149        }
150    }
151}
152
153/// A sink that maps outgoing messages to a different type
154///
155/// The conversion to the underlying message type always succeeds, so this
156/// is relatively simple.
157#[pin_project]
158pub struct MappedSendSink<S, Out, OutS> {
159    inner: S,
160    _p: std::marker::PhantomData<(Out, OutS)>,
161}
162
163impl<S, Out, Out0> MappedSendSink<S, Out, Out0> {
164    /// Create a new mapped send sink
165    pub fn new(inner: S) -> Self {
166        Self {
167            inner,
168            _p: std::marker::PhantomData,
169        }
170    }
171}
172
173impl<S, Out, Out0> futures_sink::Sink<Out> for MappedSendSink<S, Out, Out0>
174where
175    S: futures_sink::Sink<Out0> + Unpin,
176    Out: Into<Out0>,
177{
178    type Error = S::Error;
179
180    fn poll_ready(
181        self: std::pin::Pin<&mut Self>,
182        cx: &mut Context,
183    ) -> Poll<Result<(), Self::Error>> {
184        self.project().inner.poll_ready_unpin(cx)
185    }
186
187    fn start_send(self: std::pin::Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
188        self.project().inner.start_send_unpin(item.into())
189    }
190
191    fn poll_flush(
192        self: std::pin::Pin<&mut Self>,
193        cx: &mut Context,
194    ) -> Poll<Result<(), Self::Error>> {
195        self.project().inner.poll_flush_unpin(cx)
196    }
197
198    fn poll_close(
199        self: std::pin::Pin<&mut Self>,
200        cx: &mut Context,
201    ) -> Poll<Result<(), Self::Error>> {
202        self.project().inner.poll_close_unpin(cx)
203    }
204}
205
/// Stream types for a mapped connection.
///
/// A zero-sized marker. Its `ConnectionErrors` and `StreamTypes` impls describe
/// the mapped message types `In` / `Out` layered on top of an inner `C`.
pub struct MappedStreamTypes<In, Out, C>(PhantomData<(In, Out, C)>);

// Hand-written (instead of derived) so that no bounds on `In`, `Out` or `C`
// are required. The printed name matches the type's actual name.
impl<In, Out, C> Debug for MappedStreamTypes<In, Out, C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MappedStreamTypes").finish()
    }
}

// Hand-written for the same reason: cloning a marker needs no bounds.
impl<In, Out, C> Clone for MappedStreamTypes<In, Out, C> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}
220
221impl<In, Out, C> ConnectionErrors for MappedStreamTypes<In, Out, C>
222where
223    In: RpcMessage,
224    Out: RpcMessage,
225    C: ConnectionErrors,
226{
227    type RecvError = ErrorOrMapError<C::RecvError>;
228    type SendError = C::SendError;
229    type OpenError = C::OpenError;
230    type AcceptError = C::AcceptError;
231}
232
233impl<In, Out, C> StreamTypes for MappedStreamTypes<In, Out, C>
234where
235    C: StreamTypes,
236    In: RpcMessage,
237    Out: RpcMessage,
238    In: TryFrom<C::In>,
239    C::Out: From<Out>,
240{
241    type In = In;
242    type Out = Out;
243    type RecvStream = MappedRecvStream<C::RecvStream, In>;
244    type SendSink = MappedSendSink<C::SendSink, Out, C::Out>;
245}
246
#[cfg(test)]
#[cfg(feature = "flume-transport")]
mod tests {
    // Usage example for mapped transports. It shows how a client and a server
    // channel for `FullService` can be narrowed to `SubService` with `map`.

    use serde::{Deserialize, Serialize};
    use testresult::TestResult;

    use super::*;
    use crate::{
        server::{BoxedChannelTypes, RpcChannel},
        transport::Listener,
        RpcClient, RpcServer,
    };

    /// Request type of the full service.
    // The `From`/`TryInto` derives provide conversions between the enum and its
    // payload types (`u64`, `String`). They are presumably what makes
    // `map::<SubService>()` possible below; confirm against the `map` bounds.
    #[derive(Debug, Clone, Serialize, Deserialize, derive_more::From, derive_more::TryInto)]
    enum Request {
        A(u64),
        B(String),
    }

    /// Response type of the full service (same shape as `Request`).
    #[derive(Debug, Clone, Serialize, Deserialize, derive_more::From, derive_more::TryInto)]
    enum Response {
        A(u64),
        B(String),
    }

    /// The full service, using `Request` / `Response`.
    #[derive(Debug, Clone)]
    struct FullService;

    impl crate::Service for FullService {
        type Req = Request;
        type Res = Response;
    }

    /// A sub-service whose messages are the `String` payloads of the full service.
    #[derive(Debug, Clone)]
    struct SubService;

    impl crate::Service for SubService {
        type Req = String;
        type Res = String;
    }

    // NOTE(review): nothing in this test opens a connection, and the client
    // handles stay alive for the duration of the accept loop. So
    // `server.accept()` presumably never returns, which is likely why the test
    // is `#[ignore]`d. Its main value is checking that the mapping APIs
    // type-check. Confirm.
    #[tokio::test]
    #[ignore]
    async fn smoke() -> TestResult<()> {
        // Handler for requests that were routed to the sub-service.
        async fn handle_sub_request(
            _req: String,
            _chan: RpcChannel<SubService, BoxedChannelTypes<SubService>>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
        // create a listener / connector pair. Type will be inferred
        let (s, c) = crate::transport::flume::channel(32);
        // wrap the server in a RpcServer, this is where the service type is specified
        let server = RpcServer::<FullService, _>::new(s.clone());
        // when using a boxed transport, we can omit the transport type and use the default
        let _server_boxed: RpcServer<FullService> = RpcServer::<FullService>::new(s.boxed());
        // create a client in a RpcClient, this is where the service type is specified
        let client = RpcClient::<FullService, _>::new(c);
        // when using a boxed transport, we can omit the transport type and use the default
        let _boxed_client = client.clone().boxed();
        // map the client to a sub-service
        let _sub_client: RpcClient<SubService, _> = client.clone().map::<SubService>();
        // when using a boxed transport, we can omit the transport type and use the default
        let _sub_client_boxed: RpcClient<SubService> = client.clone().map::<SubService>().boxed();
        // The server itself cannot be mapped to a sub-service up front: the first
        // message is needed to decide which sub-service applies.
        while let Ok(accepting) = server.accept().await {
            let (msg, chan) = accepting.read_first().await?;
            match msg {
                // No sub-service handler is defined for variant `A` in this example.
                Request::A(_x) => todo!(),
                Request::B(x) => {
                    // but we can map the channel to the sub-service, once we know which one to use
                    handle_sub_request(x, chan.map::<SubService>().boxed()).await?
                }
            }
        }
        Ok(())
    }
}