quic_rpc/pattern/bidi_streaming.rs

//! Bidirectional streaming interaction pattern.
2
3use std::{
4    error,
5    fmt::{self, Debug},
6    result,
7};
8
9use futures_lite::{Stream, StreamExt};
10use futures_util::{FutureExt, SinkExt};
11
12use crate::{
13    client::{BoxStreamSync, UpdateSink},
14    message::{InteractionPattern, Msg},
15    server::{race2, RpcChannel, RpcServerError, UpdateStream},
16    transport::{ConnectionErrors, Connector, StreamTypes},
17    RpcClient, Service,
18};
19
20/// Bidirectional streaming interaction pattern
21///
22/// After the initial request, the client can send updates and the server can
23/// send responses.
24#[derive(Debug, Clone, Copy)]
25pub struct BidiStreaming;
26impl InteractionPattern for BidiStreaming {}
27
28/// Defines update type and response type for a bidi streaming message.
29pub trait BidiStreamingMsg<S: Service>: Msg<S, Pattern = BidiStreaming> {
30    /// The type for request updates
31    ///
32    /// For a request that does not support updates, this can be safely set to any type, including
33    /// the message type itself. Any update for such a request will result in an error.
34    type Update: Into<S::Req> + TryFrom<S::Req> + Send + 'static;
35
36    /// The type for the response
37    ///
38    /// For requests that can produce errors, this can be set to [Result<T, E>](std::result::Result).
39    type Response: Into<S::Res> + TryFrom<S::Res> + Send + 'static;
40}
41
42/// Server error when accepting a bidi request
43#[derive(Debug)]
44pub enum Error<C: ConnectionErrors> {
45    /// Unable to open a substream at all
46    Open(C::OpenError),
47    /// Unable to send the request to the server
48    Send(C::SendError),
49}
50
51impl<C: ConnectionErrors> fmt::Display for Error<C> {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        fmt::Debug::fmt(self, f)
54    }
55}
56
57impl<C: ConnectionErrors> error::Error for Error<C> {}
58
59/// Server error when receiving an item for a bidi request
60#[derive(Debug)]
61pub enum ItemError<C: ConnectionErrors> {
62    /// Unable to receive the response from the server
63    RecvError(C::RecvError),
64    /// Unexpected response from the server
65    DowncastError,
66}
67
68impl<C: ConnectionErrors> fmt::Display for ItemError<C> {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        fmt::Debug::fmt(self, f)
71    }
72}
73
74impl<C: ConnectionErrors> error::Error for ItemError<C> {}
75
76impl<S, C> RpcClient<S, C>
77where
78    S: Service,
79    C: Connector<In = S::Res, Out = S::Req>,
80{
81    /// Bidi call to the server, request opens a stream, response is a stream
82    pub async fn bidi<M>(
83        &self,
84        msg: M,
85    ) -> result::Result<
86        (
87            UpdateSink<C, M::Update>,
88            BoxStreamSync<'static, result::Result<M::Response, ItemError<C>>>,
89        ),
90        Error<C>,
91    >
92    where
93        M: BidiStreamingMsg<S>,
94    {
95        let msg = msg.into();
96        let (mut send, recv) = self.source.open().await.map_err(Error::Open)?;
97        send.send(msg).await.map_err(Error::<C>::Send)?;
98        let send = UpdateSink::new(send);
99        let recv = Box::pin(recv.map(move |x| match x {
100            Ok(msg) => M::Response::try_from(msg).map_err(|_| ItemError::DowncastError),
101            Err(e) => Err(ItemError::RecvError(e)),
102        }));
103        Ok((send, recv))
104    }
105}
106
107impl<C, S> RpcChannel<S, C>
108where
109    C: StreamTypes<In = S::Req, Out = S::Res>,
110    S: Service,
111{
112    /// handle the message M using the given function on the target object
113    ///
114    /// If you want to support concurrent requests, you need to spawn this on a tokio task yourself.
115    pub async fn bidi_streaming<M, F, Str, T>(
116        self,
117        req: M,
118        target: T,
119        f: F,
120    ) -> result::Result<(), RpcServerError<C>>
121    where
122        M: BidiStreamingMsg<S>,
123        F: FnOnce(T, M, UpdateStream<C, M::Update>) -> Str + Send + 'static,
124        Str: Stream<Item = M::Response> + Send + 'static,
125        T: Send + 'static,
126    {
127        let Self { mut send, recv, .. } = self;
128        // downcast the updates
129        let (updates, read_error) = UpdateStream::new(recv);
130        // get the response
131        let responses = f(target, req, updates);
132        race2(read_error.map(Err), async move {
133            tokio::pin!(responses);
134            while let Some(response) = responses.next().await {
135                // turn into a S::Res so we can send it
136                let response = response.into();
137                // send it and return the error if any
138                send.send(response)
139                    .await
140                    .map_err(RpcServerError::SendError)?;
141            }
142            Ok(())
143        })
144        .await
145    }
146}