1use core::fmt;
5use std::{error, fmt::Display, marker::PhantomData, pin::Pin, result, task::Poll};
6
7use futures_lite::{Future, Stream};
8use futures_sink::Sink;
9
10use super::StreamTypes;
11use crate::{
12 transport::{ConnectionErrors, Connector, Listener, LocalAddr},
13 RpcMessage,
14};
15
/// Error type for receiving from an in-memory stream.
///
/// The enum has no variants, so a value of this type can never be constructed;
/// receiving on a `RecvStream` therefore never produces an `Err`.
#[derive(Debug)]
pub enum RecvError {}

impl fmt::Display for RecvError {
    fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Uninhabited type: there is nothing to format, and this is never reached.
        match *self {}
    }
}
27
28pub struct SendSink<T: RpcMessage>(pub(crate) flume::r#async::SendSink<'static, T>);
30
31impl<T: RpcMessage> fmt::Debug for SendSink<T> {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 f.debug_struct("SendSink").finish()
34 }
35}
36
37impl<T: RpcMessage> Sink<T> for SendSink<T> {
38 type Error = self::SendError;
39
40 fn poll_ready(
41 mut self: Pin<&mut Self>,
42 cx: &mut std::task::Context<'_>,
43 ) -> Poll<Result<(), Self::Error>> {
44 Pin::new(&mut self.0)
45 .poll_ready(cx)
46 .map_err(|_| SendError::ReceiverDropped)
47 }
48
49 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
50 Pin::new(&mut self.0)
51 .start_send(item)
52 .map_err(|_| SendError::ReceiverDropped)
53 }
54
55 fn poll_flush(
56 mut self: Pin<&mut Self>,
57 cx: &mut std::task::Context<'_>,
58 ) -> Poll<Result<(), Self::Error>> {
59 Pin::new(&mut self.0)
60 .poll_flush(cx)
61 .map_err(|_| SendError::ReceiverDropped)
62 }
63
64 fn poll_close(
65 mut self: Pin<&mut Self>,
66 cx: &mut std::task::Context<'_>,
67 ) -> Poll<Result<(), Self::Error>> {
68 Pin::new(&mut self.0)
69 .poll_close(cx)
70 .map_err(|_| SendError::ReceiverDropped)
71 }
72}
73
74pub struct RecvStream<T: RpcMessage>(pub(crate) flume::r#async::RecvStream<'static, T>);
76
77impl<T: RpcMessage> fmt::Debug for RecvStream<T> {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 f.debug_struct("RecvStream").finish()
80 }
81}
82
83impl<T: RpcMessage> Stream for RecvStream<T> {
84 type Item = result::Result<T, self::RecvError>;
85
86 fn poll_next(
87 mut self: Pin<&mut Self>,
88 cx: &mut std::task::Context<'_>,
89 ) -> Poll<Option<Self::Item>> {
90 match Pin::new(&mut self.0).poll_next(cx) {
91 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
92 Poll::Ready(None) => Poll::Ready(None),
93 Poll::Pending => Poll::Pending,
94 }
95 }
96}
97
98impl error::Error for RecvError {}
99
100pub struct FlumeListener<In: RpcMessage, Out: RpcMessage> {
104 #[allow(clippy::type_complexity)]
105 stream: flume::Receiver<(SendSink<Out>, RecvStream<In>)>,
106}
107
108impl<In: RpcMessage, Out: RpcMessage> Clone for FlumeListener<In, Out> {
109 fn clone(&self) -> Self {
110 Self {
111 stream: self.stream.clone(),
112 }
113 }
114}
115
116impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for FlumeListener<In, Out> {
117 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118 f.debug_struct("FlumeListener")
119 .field("stream", &self.stream)
120 .finish()
121 }
122}
123
124impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for FlumeListener<In, Out> {
125 type SendError = self::SendError;
126 type RecvError = self::RecvError;
127 type OpenError = self::OpenError;
128 type AcceptError = self::AcceptError;
129}
130
131type Socket<In, Out> = (self::SendSink<Out>, self::RecvStream<In>);
132
133pub struct OpenFuture<In: RpcMessage, Out: RpcMessage> {
135 inner: flume::r#async::SendFut<'static, Socket<Out, In>>,
136 res: Option<Socket<In, Out>>,
137}
138
139impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for OpenFuture<In, Out> {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 f.debug_struct("OpenFuture").finish()
142 }
143}
144
145impl<In: RpcMessage, Out: RpcMessage> OpenFuture<In, Out> {
146 fn new(inner: flume::r#async::SendFut<'static, Socket<Out, In>>, res: Socket<In, Out>) -> Self {
147 Self {
148 inner,
149 res: Some(res),
150 }
151 }
152}
153
154impl<In: RpcMessage, Out: RpcMessage> Future for OpenFuture<In, Out> {
155 type Output = result::Result<Socket<In, Out>, self::OpenError>;
156
157 fn poll(
158 mut self: Pin<&mut Self>,
159 cx: &mut std::task::Context<'_>,
160 ) -> std::task::Poll<Self::Output> {
161 match Pin::new(&mut self.inner).poll(cx) {
162 Poll::Ready(Ok(())) => self
163 .res
164 .take()
165 .map(|x| Poll::Ready(Ok(x)))
166 .unwrap_or(Poll::Pending),
167 Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenError::RemoteDropped)),
168 Poll::Pending => Poll::Pending,
169 }
170 }
171}
172
173pub struct AcceptFuture<In: RpcMessage, Out: RpcMessage> {
175 wrapped: flume::r#async::RecvFut<'static, (SendSink<Out>, RecvStream<In>)>,
176 _p: PhantomData<(In, Out)>,
177}
178
179impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for AcceptFuture<In, Out> {
180 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 f.debug_struct("AcceptFuture").finish()
182 }
183}
184
185impl<In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<In, Out> {
186 type Output = result::Result<(SendSink<Out>, RecvStream<In>), AcceptError>;
187
188 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
189 match Pin::new(&mut self.wrapped).poll(cx) {
190 Poll::Ready(Ok((send, recv))) => Poll::Ready(Ok((send, recv))),
191 Poll::Ready(Err(_)) => Poll::Ready(Err(AcceptError::RemoteDropped)),
192 Poll::Pending => Poll::Pending,
193 }
194 }
195}
196
197impl<In: RpcMessage, Out: RpcMessage> StreamTypes for FlumeListener<In, Out> {
198 type In = In;
199 type Out = Out;
200 type SendSink = SendSink<Out>;
201 type RecvStream = RecvStream<In>;
202}
203
204impl<In: RpcMessage, Out: RpcMessage> Listener for FlumeListener<In, Out> {
205 #[allow(refining_impl_trait)]
206 fn accept(&self) -> AcceptFuture<In, Out> {
207 AcceptFuture {
208 wrapped: self.stream.clone().into_recv_async(),
209 _p: PhantomData,
210 }
211 }
212
213 fn local_addr(&self) -> &[LocalAddr] {
214 &[LocalAddr::Mem]
215 }
216}
217
218impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for FlumeConnector<In, Out> {
219 type SendError = self::SendError;
220 type RecvError = self::RecvError;
221 type OpenError = self::OpenError;
222 type AcceptError = self::AcceptError;
223}
224
225impl<In: RpcMessage, Out: RpcMessage> StreamTypes for FlumeConnector<In, Out> {
226 type In = In;
227 type Out = Out;
228 type SendSink = SendSink<Out>;
229 type RecvStream = RecvStream<In>;
230}
231
232impl<In: RpcMessage, Out: RpcMessage> Connector for FlumeConnector<In, Out> {
233 #[allow(refining_impl_trait)]
234 fn open(&self) -> OpenFuture<In, Out> {
235 let (local_send, remote_recv) = flume::bounded::<Out>(128);
236 let (remote_send, local_recv) = flume::bounded::<In>(128);
237 let remote_chan = (
238 SendSink(remote_send.into_sink()),
239 RecvStream(remote_recv.into_stream()),
240 );
241 let local_chan = (
242 SendSink(local_send.into_sink()),
243 RecvStream(local_recv.into_stream()),
244 );
245 OpenFuture::new(self.sink.clone().into_send_async(remote_chan), local_chan)
246 }
247}
248
249pub struct FlumeConnector<In: RpcMessage, Out: RpcMessage> {
253 #[allow(clippy::type_complexity)]
254 sink: flume::Sender<(SendSink<In>, RecvStream<Out>)>,
255}
256
257impl<In: RpcMessage, Out: RpcMessage> Clone for FlumeConnector<In, Out> {
258 fn clone(&self) -> Self {
259 Self {
260 sink: self.sink.clone(),
261 }
262 }
263}
264
265impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for FlumeConnector<In, Out> {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 f.debug_struct("FlumeClientChannel")
268 .field("sink", &self.sink)
269 .finish()
270 }
271}
272
/// Error returned by `AcceptFuture` when no connection can be received.
#[derive(Debug)]
pub enum AcceptError {
    /// Receiving from the listener's connection channel failed (e.g. every
    /// connector was dropped).
    RemoteDropped,
}

impl fmt::Display for AcceptError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Display deliberately mirrors the variant name, exactly as `Debug` prints it.
        f.write_str(match self {
            Self::RemoteDropped => "RemoteDropped",
        })
    }
}

impl std::error::Error for AcceptError {}
289
/// Error returned by `SendSink` when a message cannot be sent.
#[derive(Debug)]
pub enum SendError {
    /// The underlying flume sink reported a failure; every such failure is
    /// mapped to this variant.
    ReceiverDropped,
}

impl fmt::Display for SendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Display deliberately mirrors the variant name, exactly as `Debug` prints it.
        f.write_str(match self {
            Self::ReceiverDropped => "ReceiverDropped",
        })
    }
}

impl std::error::Error for SendError {}
306
/// Error returned by `OpenFuture` when a new connection cannot be handed to the listener.
#[derive(Debug)]
pub enum OpenError {
    /// Sending the remote half into the listener's channel failed (e.g. the
    /// listener side was dropped).
    RemoteDropped,
}

impl fmt::Display for OpenError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Display deliberately mirrors the variant name, exactly as `Debug` prints it.
        f.write_str(match self {
            Self::RemoteDropped => "RemoteDropped",
        })
    }
}

impl std::error::Error for OpenError {}
321
/// Error type for creating an in-memory channel.
///
/// The enum has no variants, so a value of it can never exist.
#[derive(Debug, Clone, Copy)]
pub enum CreateChannelError {}

impl fmt::Display for CreateChannelError {
    fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Uninhabited type: there is nothing to format, and this is never reached.
        match *self {}
    }
}

impl std::error::Error for CreateChannelError {}
336
337pub fn channel<Req: RpcMessage, Res: RpcMessage>(
341 buffer: usize,
342) -> (FlumeListener<Req, Res>, FlumeConnector<Res, Req>) {
343 let (sink, stream) = flume::bounded(buffer);
344 (FlumeListener { stream }, FlumeConnector { sink })
345}