quic_rpc/
server.rs

1//! Server side api
2//!
3//! The main entry point is [RpcServer]
4use std::{
5    error,
6    fmt::{self, Debug},
7    marker::PhantomData,
8    pin::Pin,
9    result,
10    sync::Arc,
11    task::{self, Poll},
12};
13
14use futures_lite::{Future, Stream, StreamExt};
15use futures_util::{SinkExt, TryStreamExt};
16use pin_project::pin_project;
17use tokio::{sync::oneshot, task::JoinSet};
18use tokio_util::task::AbortOnDropHandle;
19use tracing::{error, warn};
20
21use crate::{
22    transport::{
23        self,
24        boxed::BoxableListener,
25        mapped::{ErrorOrMapError, MappedRecvStream, MappedSendSink, MappedStreamTypes},
26        ConnectionErrors, StreamTypes,
27    },
28    Listener, RpcMessage, Service,
29};
30
31/// Stream types on the server side
32///
33/// On the server side, we receive requests and send responses.
34/// On the client side, we send requests and receive responses.
35pub trait ChannelTypes<S: Service>: transport::StreamTypes<In = S::Req, Out = S::Res> {}
36
37impl<T: transport::StreamTypes<In = S::Req, Out = S::Res>, S: Service> ChannelTypes<S> for T {}
38
39/// Type alias for when you want to require a boxed channel
40pub type BoxedChannelTypes<S> = crate::transport::boxed::BoxedStreamTypes<
41    <S as crate::Service>::Req,
42    <S as crate::Service>::Res,
43>;
44
45/// A boxed listener for the given [`Service`]
46pub type BoxedListener<S> =
47    crate::transport::boxed::BoxedListener<<S as crate::Service>::Req, <S as crate::Service>::Res>;
48
49#[cfg(feature = "flume-transport")]
50#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "flume-transport")))]
51/// A flume listener for the given [`Service`]
52pub type FlumeListener<S> =
53    crate::transport::flume::FlumeListener<<S as Service>::Req, <S as Service>::Res>;
54
55#[cfg(feature = "quinn-transport")]
56#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "quinn-transport")))]
57/// A quinn listener for the given [`Service`]
58pub type QuinnListener<S> =
59    crate::transport::quinn::QuinnListener<<S as Service>::Req, <S as Service>::Res>;
60
61#[cfg(feature = "hyper-transport")]
62#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "hyper-transport")))]
63/// A hyper listener for the given [`Service`]
64pub type HyperListener<S> =
65    crate::transport::hyper::HyperListener<<S as Service>::Req, <S as Service>::Res>;
66
67#[cfg(feature = "iroh-transport")]
68#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "iroh-transport")))]
69/// An iroh listener for the given [`Service`]
70pub type IrohListener<S> =
71    crate::transport::iroh::IrohListener<<S as Service>::Req, <S as Service>::Res>;
72
73/// A server for a specific service.
74///
75/// This is a wrapper around a [`Listener`] that serves as the entry point for the server DSL.
76///
77/// Type parameters:
78///
79/// `S` is the service type.
80/// `C` is the channel type.
81#[derive(Debug)]
82pub struct RpcServer<S, C = BoxedListener<S>> {
83    /// The channel on which new requests arrive.
84    ///
85    /// Each new request is a receiver and channel pair on which messages for this request
86    /// are received and responses sent.
87    source: C,
88    _p: PhantomData<S>,
89}
90
91impl<S, C: Clone> Clone for RpcServer<S, C> {
92    fn clone(&self) -> Self {
93        Self {
94            source: self.source.clone(),
95            _p: PhantomData,
96        }
97    }
98}
99
100impl<S: Service, C: Listener<S>> RpcServer<S, C> {
101    /// Create a new rpc server for a specific service for a [Service] given a compatible
102    /// [Listener].
103    ///
104    /// This is where a generic typed endpoint is converted into a server for a specific service.
105    pub fn new(source: C) -> Self {
106        Self {
107            source,
108            _p: PhantomData,
109        }
110    }
111
112    /// Box the transport for the service.
113    ///
114    /// The boxed transport is the default for the `C` type parameter, so by boxing we can avoid
115    /// having to specify the type parameter.
116    pub fn boxed(self) -> RpcServer<S, BoxedListener<S>>
117    where
118        C: BoxableListener<S::Req, S::Res>,
119    {
120        RpcServer::new(self.source.boxed())
121    }
122}
123
124/// A channel for requests and responses for a specific service.
125///
126/// This just groups the sink and stream into a single type, and attaches the
127/// information about the service type.
128///
129/// Sink and stream are independent, so you can take the channel apart and use
130/// them independently.
131///
132/// Type parameters:
133///
134/// `S` is the service type.
135/// `C` is the service endpoint from which the channel was created.
136#[derive(Debug)]
137pub struct RpcChannel<S: Service, C: ChannelTypes<S> = BoxedChannelTypes<S>> {
138    /// Sink to send responses to the client.
139    pub send: C::SendSink,
140    /// Stream to receive requests from the client.
141    pub recv: C::RecvStream,
142
143    pub(crate) _p: PhantomData<S>,
144}
145
146impl<S, C> RpcChannel<S, C>
147where
148    S: Service,
149    C: StreamTypes<In = S::Req, Out = S::Res>,
150{
151    /// Create a new RPC channel.
152    pub fn new(send: C::SendSink, recv: C::RecvStream) -> Self {
153        Self {
154            send,
155            recv,
156            _p: PhantomData,
157        }
158    }
159
160    /// Convert this channel into a boxed channel.
161    pub fn boxed(self) -> RpcChannel<S, BoxedChannelTypes<S>>
162    where
163        C::SendError: Into<anyhow::Error> + Send + Sync + 'static,
164        C::RecvError: Into<anyhow::Error> + Send + Sync + 'static,
165    {
166        let send =
167            transport::boxed::SendSink::boxed(Box::new(self.send.sink_map_err(|e| e.into())));
168        let recv = transport::boxed::RecvStream::boxed(Box::new(self.recv.map_err(|e| e.into())));
169        RpcChannel::new(send, recv)
170    }
171
172    /// Map this channel's service into an inner service.
173    ///
174    /// This method is available if the required bounds are upheld:
175    /// SNext::Req: Into<S::Req> + TryFrom<S::Req>,
176    /// SNext::Res: Into<S::Res> + TryFrom<S::Res>,
177    ///
178    /// Where SNext is the new service to map to and S is the current inner service.
179    ///
180    /// This method can be chained infintely.
181    pub fn map<SNext>(self) -> RpcChannel<SNext, MappedStreamTypes<SNext::Req, SNext::Res, C>>
182    where
183        SNext: Service,
184        SNext::Req: TryFrom<S::Req>,
185        S::Res: From<SNext::Res>,
186    {
187        RpcChannel::new(
188            MappedSendSink::new(self.send),
189            MappedRecvStream::new(self.recv),
190        )
191    }
192}
193
194/// The result of accepting a new connection.
195pub struct Accepting<S: Service, C: Listener<S>> {
196    send: C::SendSink,
197    recv: C::RecvStream,
198    _p: PhantomData<S>,
199}
200
201impl<S: Service, C: Listener<S>> Accepting<S, C> {
202    /// Read the first message from the client.
203    ///
204    /// The return value is a tuple of `(request, channel)`.  Here `request` is the
205    /// first request which is already read from the stream.  The `channel` is a
206    /// [RpcChannel] that has `sink` and `stream` fields that can be used to send more
207    /// requests and/or receive more responses.
208    ///
209    /// Often sink and stream will wrap an an underlying byte stream. In this case you can
210    /// call into_inner() on them to get it back to perform byte level reads and writes.
211    pub async fn read_first(self) -> result::Result<(S::Req, RpcChannel<S, C>), RpcServerError<C>> {
212        let Accepting { send, mut recv, .. } = self;
213        // get the first message from the client. This will tell us what it wants to do.
214        let request: S::Req = recv
215            .next()
216            .await
217            // no msg => early close
218            .ok_or(RpcServerError::EarlyClose)?
219            // recv error
220            .map_err(RpcServerError::RecvError)?;
221        Ok((request, RpcChannel::<S, C>::new(send, recv)))
222    }
223}
224
225impl<S: Service, C: Listener<S>> RpcServer<S, C> {
226    /// Accepts a new channel from a client. The result is an [Accepting] object that
227    /// can be used to read the first request.
228    pub async fn accept(&self) -> result::Result<Accepting<S, C>, RpcServerError<C>> {
229        let (send, recv) = self.source.accept().await.map_err(RpcServerError::Accept)?;
230        Ok(Accepting {
231            send,
232            recv,
233            _p: PhantomData,
234        })
235    }
236
237    /// Get the underlying service endpoint
238    pub fn into_inner(self) -> C {
239        self.source
240    }
241
242    /// Run an accept loop for this server.
243    ///
244    /// Each request will be handled in a separate task.
245    ///
246    /// It is the caller's responsibility to poll the returned future to drive the server.
247    pub async fn accept_loop<Fun, Fut, E>(self, handler: Fun)
248    where
249        S: Service,
250        C: Listener<S>,
251        Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
252        Fut: Future<Output = Result<(), E>> + Send + 'static,
253        E: Into<anyhow::Error> + 'static,
254    {
255        let handler = Arc::new(handler);
256        let mut tasks = JoinSet::new();
257        loop {
258            tokio::select! {
259                Some(res) = tasks.join_next(), if !tasks.is_empty() => {
260                    if let Err(e) = res {
261                        if e.is_panic() {
262                            error!("Panic handling RPC request: {e}");
263                        }
264                    }
265                }
266                req = self.accept() => {
267                    let req = match req {
268                        Ok(req) => req,
269                        Err(e) => {
270                            warn!("Error accepting RPC request: {e}");
271                            continue;
272                        }
273                    };
274                    let handler = handler.clone();
275                    tasks.spawn(async move {
276                        let (req, chan) = match req.read_first().await {
277                            Ok((req, chan)) => (req, chan),
278                            Err(e) => {
279                                warn!("Error reading first message: {e}");
280                                return;
281                            }
282                        };
283                        if let Err(cause) = handler(req, chan).await {
284                            warn!("Error handling RPC request: {}", cause.into());
285                        }
286                    });
287                }
288            }
289        }
290    }
291
292    /// Spawn an accept loop and return a handle to the task.
293    pub fn spawn_accept_loop<Fun, Fut, E>(self, handler: Fun) -> AbortOnDropHandle<()>
294    where
295        S: Service,
296        C: Listener<S>,
297        Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
298        Fut: Future<Output = Result<(), E>> + Send + 'static,
299        E: Into<anyhow::Error> + 'static,
300    {
301        AbortOnDropHandle::new(tokio::spawn(self.accept_loop(handler)))
302    }
303}
304
305impl<S: Service, C: Listener<S>> AsRef<C> for RpcServer<S, C> {
306    fn as_ref(&self) -> &C {
307        &self.source
308    }
309}
310
311/// A stream of updates
312///
313/// If there is any error with receiving or with decoding the updates, the stream will stall and the error will
314/// cause a termination of the RPC call.
315#[pin_project]
316#[derive(Debug)]
317pub struct UpdateStream<C, T>(
318    #[pin] C::RecvStream,
319    Option<oneshot::Sender<RpcServerError<C>>>,
320    PhantomData<T>,
321)
322where
323    C: StreamTypes;
324
325impl<C, T> UpdateStream<C, T>
326where
327    C: StreamTypes,
328    T: TryFrom<C::In>,
329{
330    pub(crate) fn new(recv: C::RecvStream) -> (Self, UnwrapToPending<RpcServerError<C>>) {
331        let (error_send, error_recv) = oneshot::channel();
332        let error_recv = UnwrapToPending(futures_lite::future::fuse(error_recv));
333        (Self(recv, Some(error_send), PhantomData), error_recv)
334    }
335}
336
337impl<C, T> Stream for UpdateStream<C, T>
338where
339    C: StreamTypes,
340    T: TryFrom<C::In>,
341{
342    type Item = T;
343
344    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
345        let mut this = self.project();
346        match Pin::new(&mut this.0).poll_next(cx) {
347            Poll::Ready(Some(msg)) => match msg {
348                Ok(msg) => {
349                    let msg = T::try_from(msg).map_err(|_cause| ());
350                    match msg {
351                        Ok(msg) => Poll::Ready(Some(msg)),
352                        Err(_cause) => {
353                            // we were unable to downcast, so we need to send an error
354                            if let Some(tx) = this.1.take() {
355                                let _ = tx.send(RpcServerError::UnexpectedUpdateMessage);
356                            }
357                            Poll::Pending
358                        }
359                    }
360                }
361                Err(cause) => {
362                    // we got a recv error, so return pending and send the error
363                    if let Some(tx) = this.1.take() {
364                        let _ = tx.send(RpcServerError::RecvError(cause));
365                    }
366                    Poll::Pending
367                }
368            },
369            Poll::Ready(None) => Poll::Ready(None),
370            Poll::Pending => Poll::Pending,
371        }
372    }
373}
374
375/// Server error. All server DSL methods return a `Result` with this error type.
376pub enum RpcServerError<C: ConnectionErrors> {
377    /// Unable to open a new channel
378    Accept(C::AcceptError),
379    /// Recv side for a channel was closed before getting the first message
380    EarlyClose,
381    /// Got an unexpected first message, e.g. an update message
382    UnexpectedStartMessage,
383    /// Error receiving a message
384    RecvError(C::RecvError),
385    /// Error sending a response
386    SendError(C::SendError),
387    /// Got an unexpected update message, e.g. a request message or a non-matching update message
388    UnexpectedUpdateMessage,
389}
390
391impl<In: RpcMessage, Out: RpcMessage, C: ConnectionErrors>
392    RpcServerError<MappedStreamTypes<In, Out, C>>
393{
394    /// For a mapped connection, map the error back to the original error type
395    pub fn map_back(self) -> RpcServerError<C> {
396        match self {
397            RpcServerError::EarlyClose => RpcServerError::EarlyClose,
398            RpcServerError::UnexpectedStartMessage => RpcServerError::UnexpectedStartMessage,
399            RpcServerError::UnexpectedUpdateMessage => RpcServerError::UnexpectedUpdateMessage,
400            RpcServerError::SendError(x) => RpcServerError::SendError(x),
401            RpcServerError::Accept(x) => RpcServerError::Accept(x),
402            RpcServerError::RecvError(ErrorOrMapError::Inner(x)) => RpcServerError::RecvError(x),
403            RpcServerError::RecvError(ErrorOrMapError::Conversion) => {
404                RpcServerError::UnexpectedUpdateMessage
405            }
406        }
407    }
408}
409
410impl<C: ConnectionErrors> RpcServerError<C> {
411    /// Convert into a different error type provided the send, recv and accept errors can be converted
412    pub fn errors_into<T>(self) -> RpcServerError<T>
413    where
414        T: ConnectionErrors,
415        C::SendError: Into<T::SendError>,
416        C::RecvError: Into<T::RecvError>,
417        C::AcceptError: Into<T::AcceptError>,
418    {
419        match self {
420            RpcServerError::EarlyClose => RpcServerError::EarlyClose,
421            RpcServerError::UnexpectedStartMessage => RpcServerError::UnexpectedStartMessage,
422            RpcServerError::UnexpectedUpdateMessage => RpcServerError::UnexpectedUpdateMessage,
423            RpcServerError::SendError(x) => RpcServerError::SendError(x.into()),
424            RpcServerError::Accept(x) => RpcServerError::Accept(x.into()),
425            RpcServerError::RecvError(x) => RpcServerError::RecvError(x.into()),
426        }
427    }
428}
429
430impl<C: ConnectionErrors> fmt::Debug for RpcServerError<C> {
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        match self {
433            Self::Accept(arg0) => f.debug_tuple("Open").field(arg0).finish(),
434            Self::EarlyClose => write!(f, "EarlyClose"),
435            Self::RecvError(arg0) => f.debug_tuple("RecvError").field(arg0).finish(),
436            Self::SendError(arg0) => f.debug_tuple("SendError").field(arg0).finish(),
437            Self::UnexpectedStartMessage => f.debug_tuple("UnexpectedStartMessage").finish(),
438            Self::UnexpectedUpdateMessage => f.debug_tuple("UnexpectedStartMessage").finish(),
439        }
440    }
441}
442
443impl<C: ConnectionErrors> fmt::Display for RpcServerError<C> {
444    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445        fmt::Debug::fmt(&self, f)
446    }
447}
448
449impl<C: ConnectionErrors> error::Error for RpcServerError<C> {}
450
451/// Take an oneshot receiver and just return Pending the underlying future returns `Err(oneshot::Canceled)`
452pub(crate) struct UnwrapToPending<T>(futures_lite::future::Fuse<oneshot::Receiver<T>>);
453
454impl<T> Future for UnwrapToPending<T> {
455    type Output = T;
456
457    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
458        // todo: use is_terminated from tokio 1.44 here to avoid the fused wrapper
459        match Pin::new(&mut self.0).poll(cx) {
460            Poll::Ready(Ok(x)) => Poll::Ready(x),
461            Poll::Ready(Err(_)) => Poll::Pending,
462            Poll::Pending => Poll::Pending,
463        }
464    }
465}
466
467pub(crate) async fn race2<T, A: Future<Output = T>, B: Future<Output = T>>(f1: A, f2: B) -> T {
468    tokio::select! {
469        x = f1 => x,
470        x = f2 => x,
471    }
472}
473
474/// Run a server loop, invoking a handler callback for each request.
475///
476/// Requests will be handled sequentially.
477pub async fn run_server_loop<S, C, T, F, Fut>(
478    _service_type: S,
479    conn: C,
480    target: T,
481    mut handler: F,
482) -> Result<(), RpcServerError<C>>
483where
484    S: Service,
485    C: Listener<S>,
486    T: Clone + Send + 'static,
487    F: FnMut(RpcChannel<S, C>, S::Req, T) -> Fut + Send + 'static,
488    Fut: Future<Output = Result<(), RpcServerError<C>>> + Send + 'static,
489{
490    let server: RpcServer<S, C> = RpcServer::<S, C>::new(conn);
491    loop {
492        let (req, chan) = server.accept().await?.read_first().await?;
493        let target = target.clone();
494        handler(chan, req, target).await?;
495    }
496}