1use std::{
3 fmt::{Debug, Display},
4 marker::PhantomData,
5 task::{Context, Poll},
6};
7
8use futures_lite::{Stream, StreamExt};
9use futures_util::SinkExt;
10use pin_project::pin_project;
11
12use super::{ConnectionErrors, Connector, StreamTypes};
13use crate::{RpcError, RpcMessage};
14
/// A connector adapter that converts the message types of an inner connector `C`.
///
/// Messages received from `C` (of type `C::In`) are converted into `In` via [`TryFrom`], and
/// messages of type `Out` are converted into `C::Out` via [`From`] before being sent.
pub struct MappedConnector<In, Out, C> {
    inner: C,
    _p: std::marker::PhantomData<(In, Out)>,
}

// `Debug` is written by hand instead of derived: `#[derive(Debug)]` would require
// `In: Debug` and `Out: Debug` even though those types only occur inside `PhantomData`
// (the `Clone` impl avoids the same spurious bounds). The output matches the derived one.
impl<In, Out, C: Debug> Debug for MappedConnector<In, Out, C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MappedConnector")
            .field("inner", &self.inner)
            .field("_p", &self._p)
            .finish()
    }
}
21
22impl<In, Out, C> MappedConnector<In, Out, C>
23where
24 C: Connector,
25 In: TryFrom<C::In>,
26 C::Out: From<Out>,
27{
28 pub fn new(inner: C) -> Self {
30 Self {
31 inner,
32 _p: std::marker::PhantomData,
33 }
34 }
35}
36
37impl<In, Out, C> Clone for MappedConnector<In, Out, C>
38where
39 C: Clone,
40{
41 fn clone(&self) -> Self {
42 Self {
43 inner: self.inner.clone(),
44 _p: std::marker::PhantomData,
45 }
46 }
47}
48
49impl<In, Out, C> ConnectionErrors for MappedConnector<In, Out, C>
50where
51 In: RpcMessage,
52 Out: RpcMessage,
53 C: ConnectionErrors,
54{
55 type RecvError = ErrorOrMapError<C::RecvError>;
56 type SendError = C::SendError;
57 type OpenError = C::OpenError;
58 type AcceptError = C::AcceptError;
59}
60
61impl<In, Out, C> StreamTypes for MappedConnector<In, Out, C>
62where
63 C: StreamTypes,
64 In: RpcMessage,
65 Out: RpcMessage,
66 In: TryFrom<C::In>,
67 C::Out: From<Out>,
68{
69 type In = In;
70 type Out = Out;
71 type RecvStream = MappedRecvStream<C::RecvStream, In>;
72 type SendSink = MappedSendSink<C::SendSink, Out, C::Out>;
73}
74
75impl<In, Out, C> Connector for MappedConnector<In, Out, C>
76where
77 C: Connector,
78 In: RpcMessage,
79 Out: RpcMessage,
80 In: TryFrom<C::In>,
81 C::Out: From<Out>,
82{
83 fn open(
84 &self,
85 ) -> impl std::future::Future<Output = Result<(Self::SendSink, Self::RecvStream), Self::OpenError>>
86 + Send {
87 let inner = self.inner.open();
88 async move {
89 let (send, recv) = inner.await?;
90 Ok((MappedSendSink::new(send), MappedRecvStream::new(recv)))
91 }
92 }
93}
94
95#[pin_project]
97pub struct MappedRecvStream<S, In> {
98 inner: S,
99 _p: std::marker::PhantomData<In>,
100}
101
102impl<S, In> MappedRecvStream<S, In> {
103 pub fn new(inner: S) -> Self {
105 Self {
106 inner,
107 _p: std::marker::PhantomData,
108 }
109 }
110}
111
/// Error produced when receiving from a mapped stream: either the inner stream's own
/// error, or a received message that could not be converted into the target type.
#[derive(Debug)]
pub enum ErrorOrMapError<E> {
    /// Error reported by the inner stream.
    Inner(E),
    /// A message was received but the `TryFrom` conversion failed.
    Conversion,
}

impl<E> std::error::Error for ErrorOrMapError<E> where E: Debug + Display {}

impl<E> Display for ErrorOrMapError<E>
where
    E: Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Inner(source) => write!(f, "Inner error: {}", source),
            Self::Conversion => f.write_str("Conversion error"),
        }
    }
}
131
132impl<S, In0, In, E> Stream for MappedRecvStream<S, In>
133where
134 S: Stream<Item = Result<In0, E>> + Unpin,
135 In: TryFrom<In0>,
136 E: RpcError,
137{
138 type Item = Result<In, ErrorOrMapError<E>>;
139
140 fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
141 match self.project().inner.poll_next(cx) {
142 Poll::Ready(Some(Ok(item))) => {
143 let item = item.try_into().map_err(|_| ErrorOrMapError::Conversion);
144 Poll::Ready(Some(item))
145 }
146 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(ErrorOrMapError::Inner(e)))),
147 Poll::Ready(None) => Poll::Ready(None),
148 Poll::Pending => Poll::Pending,
149 }
150 }
151}
152
153#[pin_project]
158pub struct MappedSendSink<S, Out, OutS> {
159 inner: S,
160 _p: std::marker::PhantomData<(Out, OutS)>,
161}
162
163impl<S, Out, Out0> MappedSendSink<S, Out, Out0> {
164 pub fn new(inner: S) -> Self {
166 Self {
167 inner,
168 _p: std::marker::PhantomData,
169 }
170 }
171}
172
173impl<S, Out, Out0> futures_sink::Sink<Out> for MappedSendSink<S, Out, Out0>
174where
175 S: futures_sink::Sink<Out0> + Unpin,
176 Out: Into<Out0>,
177{
178 type Error = S::Error;
179
180 fn poll_ready(
181 self: std::pin::Pin<&mut Self>,
182 cx: &mut Context,
183 ) -> Poll<Result<(), Self::Error>> {
184 self.project().inner.poll_ready_unpin(cx)
185 }
186
187 fn start_send(self: std::pin::Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
188 self.project().inner.start_send_unpin(item.into())
189 }
190
191 fn poll_flush(
192 self: std::pin::Pin<&mut Self>,
193 cx: &mut Context,
194 ) -> Poll<Result<(), Self::Error>> {
195 self.project().inner.poll_flush_unpin(cx)
196 }
197
198 fn poll_close(
199 self: std::pin::Pin<&mut Self>,
200 cx: &mut Context,
201 ) -> Poll<Result<(), Self::Error>> {
202 self.project().inner.poll_close_unpin(cx)
203 }
204}
205
/// Zero-sized marker type describing the stream types of an inner `C` whose messages
/// are mapped to `In` (received) and `Out` (sent).
pub struct MappedStreamTypes<In, Out, C>(PhantomData<(In, Out, C)>);

// `Debug` and `Clone` are written by hand: deriving them would add `In`/`Out`/`C` bounds
// even though those types only occur inside `PhantomData`.
impl<In, Out, C> Debug for MappedStreamTypes<In, Out, C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report this type's own name. The previous output used the stale name
        // "MappedConnectionTypes".
        f.debug_struct("MappedStreamTypes").finish()
    }
}

impl<In, Out, C> Clone for MappedStreamTypes<In, Out, C> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}
220
221impl<In, Out, C> ConnectionErrors for MappedStreamTypes<In, Out, C>
222where
223 In: RpcMessage,
224 Out: RpcMessage,
225 C: ConnectionErrors,
226{
227 type RecvError = ErrorOrMapError<C::RecvError>;
228 type SendError = C::SendError;
229 type OpenError = C::OpenError;
230 type AcceptError = C::AcceptError;
231}
232
233impl<In, Out, C> StreamTypes for MappedStreamTypes<In, Out, C>
234where
235 C: StreamTypes,
236 In: RpcMessage,
237 Out: RpcMessage,
238 In: TryFrom<C::In>,
239 C::Out: From<Out>,
240{
241 type In = In;
242 type Out = Out;
243 type RecvStream = MappedRecvStream<C::RecvStream, In>;
244 type SendSink = MappedSendSink<C::SendSink, Out, C::Out>;
245}
246
/// Tests for the message-type mapping adapters. Only built with the `flume-transport`
/// feature, because they use the in-process flume channel transport.
#[cfg(test)]
#[cfg(feature = "flume-transport")]
mod tests {

    use serde::{Deserialize, Serialize};
    use testresult::TestResult;

    use super::*;
    use crate::{
        server::{BoxedChannelTypes, RpcChannel},
        transport::Listener,
        RpcClient, RpcServer,
    };

    /// Request type of the "full" service. Variant `B` carries the `String` request
    /// type of `SubService`. NOTE(review): presumably the `derive_more` `From`/`TryInto`
    /// derives supply the conversions that mapping to `SubService` relies on — confirm.
    #[derive(Debug, Clone, Serialize, Deserialize, derive_more::From, derive_more::TryInto)]
    enum Request {
        A(u64),
        B(String),
    }

    /// Response type of the "full" service; variant `B` carries `SubService`'s response type.
    #[derive(Debug, Clone, Serialize, Deserialize, derive_more::From, derive_more::TryInto)]
    enum Response {
        A(u64),
        B(String),
    }

    /// Service using the full request/response enums.
    #[derive(Debug, Clone)]
    struct FullService;

    impl crate::Service for FullService {
        type Req = Request;
        type Res = Response;
    }

    /// Service whose messages are a subset (`String`) of `FullService`'s messages.
    #[derive(Debug, Clone)]
    struct SubService;

    impl crate::Service for SubService {
        type Req = String;
        type Res = String;
    }

    // Ignored by default: nothing is ever sent through `client`, so the
    // `server.accept()` loop below has no connection to process.
    // NOTE(review): this assumes `accept` waits for an incoming connection; the test's
    // main value appears to be checking that the mapping/boxing APIs type-check.
    #[tokio::test]
    #[ignore]
    async fn smoke() -> TestResult<()> {
        // Stand-in handler that takes a boxed channel of the sub-service.
        async fn handle_sub_request(
            _req: String,
            _chan: RpcChannel<SubService, BoxedChannelTypes<SubService>>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
        let (s, c) = crate::transport::flume::channel(32);
        // Server over the concrete flume listener...
        let server = RpcServer::<FullService, _>::new(s.clone());
        // ...and over a boxed listener.
        let _server_boxed: RpcServer<FullService> = RpcServer::<FullService>::new(s.boxed());
        let client = RpcClient::<FullService, _>::new(c);
        // Client variants: boxed, mapped to the sub-service, and mapped-then-boxed.
        let _boxed_client = client.clone().boxed();
        let _sub_client: RpcClient<SubService, _> = client.clone().map::<SubService>();
        let _sub_client_boxed: RpcClient<SubService> = client.clone().map::<SubService>().boxed();
        // Loop until `accept` fails; `Request::B` is routed to the sub-service
        // handler through a mapped and boxed channel.
        while let Ok(accepting) = server.accept().await {
            let (msg, chan) = accepting.read_first().await?;
            match msg {
                Request::A(_x) => todo!(),
                Request::B(x) => {
                    handle_sub_request(x, chan.map::<SubService>().boxed()).await?
                }
            }
        }
        Ok(())
    }
}