quic_rpc/pattern/
server_streaming.rs

1//! Server streaming interaction pattern.
2
3use std::{
4    error,
5    fmt::{self, Debug},
6    result,
7};
8
9use futures_lite::{Stream, StreamExt};
10use futures_util::{FutureExt, SinkExt, TryFutureExt};
11
12use crate::{
13    client::{BoxStreamSync, DeferDrop},
14    message::{InteractionPattern, Msg},
15    server::{race2, RpcChannel, RpcServerError},
16    transport::{ConnectionErrors, Connector, StreamTypes},
17    RpcClient, Service,
18};
19
20/// Server streaming interaction pattern
21///
22/// After the initial request, the server can send a stream of responses.
23#[derive(Debug, Clone, Copy)]
24pub struct ServerStreaming;
25impl InteractionPattern for ServerStreaming {}
26
27/// Defines response type for a server streaming message.
28pub trait ServerStreamingMsg<S: Service>: Msg<S, Pattern = ServerStreaming> {
29    /// The type for the response
30    ///
31    /// For requests that can produce errors, this can be set to [Result<T, E>](std::result::Result).
32    type Response: Into<S::Res> + TryFrom<S::Res> + Send + 'static;
33}
34
35/// Server error when accepting a server streaming request
36#[derive(Debug)]
37pub enum Error<C: ConnectionErrors> {
38    /// Unable to open a substream at all
39    Open(C::OpenError),
40    /// Unable to send the request to the server
41    Send(C::SendError),
42}
43
44impl<S: Connector> fmt::Display for Error<S> {
45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46        fmt::Debug::fmt(self, f)
47    }
48}
49
50impl<S: Connector> error::Error for Error<S> {}
51
52/// Client error when handling responses from a server streaming request
53#[derive(Debug)]
54pub enum ItemError<S: ConnectionErrors> {
55    /// Unable to receive the response from the server
56    RecvError(S::RecvError),
57    /// Unexpected response from the server
58    DowncastError,
59}
60
61impl<S: ConnectionErrors> fmt::Display for ItemError<S> {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        fmt::Debug::fmt(self, f)
64    }
65}
66
67impl<S: ConnectionErrors> error::Error for ItemError<S> {}
68
69impl<S, C> RpcClient<S, C>
70where
71    C: crate::Connector<S>,
72    S: Service,
73{
74    /// Bidi call to the server, request opens a stream, response is a stream
75    pub async fn server_streaming<M>(
76        &self,
77        msg: M,
78    ) -> result::Result<BoxStreamSync<'static, result::Result<M::Response, ItemError<C>>>, Error<C>>
79    where
80        M: ServerStreamingMsg<S>,
81    {
82        let msg = msg.into();
83        let (mut send, recv) = self.source.open().await.map_err(Error::Open)?;
84        send.send(msg).map_err(Error::<C>::Send).await?;
85        let recv = recv.map(move |x| match x {
86            Ok(msg) => M::Response::try_from(msg).map_err(|_| ItemError::DowncastError),
87            Err(e) => Err(ItemError::RecvError(e)),
88        });
89        // keep send alive so the request on the server side does not get cancelled
90        let recv = Box::pin(DeferDrop(recv, send));
91        Ok(recv)
92    }
93}
94
95impl<S, C> RpcChannel<S, C>
96where
97    S: Service,
98    C: StreamTypes<In = S::Req, Out = S::Res>,
99{
100    /// handle the message M using the given function on the target object
101    ///
102    /// If you want to support concurrent requests, you need to spawn this on a tokio task yourself.
103    pub async fn server_streaming<M, F, Str, T>(
104        self,
105        req: M,
106        target: T,
107        f: F,
108    ) -> result::Result<(), RpcServerError<C>>
109    where
110        M: ServerStreamingMsg<S>,
111        F: FnOnce(T, M) -> Str + Send + 'static,
112        Str: Stream<Item = M::Response> + Send + 'static,
113        T: Send + 'static,
114    {
115        let Self {
116            mut send, mut recv, ..
117        } = self;
118        // cancel if we get an update, no matter what it is
119        let cancel = recv
120            .next()
121            .map(|_| RpcServerError::UnexpectedUpdateMessage::<C>);
122        // race the computation and the cancellation
123        race2(cancel.map(Err), async move {
124            // get the response
125            let responses = f(target, req);
126            tokio::pin!(responses);
127            while let Some(response) = responses.next().await {
128                // turn into a S::Res so we can send it
129                let response = response.into();
130                // send it and return the error if any
131                send.send(response)
132                    .await
133                    .map_err(RpcServerError::SendError)?;
134            }
135            Ok(())
136        })
137        .await
138    }
139}