quic_rpc/pattern/
rpc.rs

1//! RPC interaction pattern.
2
3use std::{
4    error,
5    fmt::{self, Debug},
6    result,
7};
8
9use futures_lite::{Future, StreamExt};
10use futures_util::{FutureExt, SinkExt};
11
12use crate::{
13    message::{InteractionPattern, Msg},
14    server::{race2, RpcChannel, RpcServerError},
15    transport::{ConnectionErrors, StreamTypes},
16    Connector, RpcClient, Service,
17};
18
19/// Rpc interaction pattern
20///
21/// There is only one request and one response.
22#[derive(Debug, Clone, Copy)]
23pub struct Rpc;
24impl InteractionPattern for Rpc {}
25
26/// Defines the response type for a rpc message.
27///
28/// Since this is the most common interaction pattern, this also implements [Msg] for you
29/// automatically, with the interaction pattern set to [Rpc]. This is to reduce boilerplate
30/// when defining rpc messages.
31pub trait RpcMsg<S: Service>: Msg<S, Pattern = Rpc> {
32    /// The type for the response
33    ///
34    /// For requests that can produce errors, this can be set to [Result<T, E>](std::result::Result).
35    type Response: Into<S::Res> + TryFrom<S::Res> + Send + 'static;
36}
37
38/// We can only do this for one trait, so we do it for RpcMsg since it is the most common
39impl<T: RpcMsg<S>, S: Service> Msg<S> for T {
40    type Pattern = Rpc;
41}
42/// Client error. All client DSL methods return a `Result` with this error type.
43#[derive(Debug)]
44pub enum Error<C: ConnectionErrors> {
45    /// Unable to open a substream at all
46    Open(C::OpenError),
47    /// Unable to send the request to the server
48    Send(C::SendError),
49    /// Server closed the stream before sending a response
50    EarlyClose,
51    /// Unable to receive the response from the server
52    RecvError(C::RecvError),
53    /// Unexpected response from the server
54    DowncastError,
55}
56
57impl<C: ConnectionErrors> fmt::Display for Error<C> {
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        fmt::Debug::fmt(self, f)
60    }
61}
62
63impl<C: ConnectionErrors> error::Error for Error<C> {}
64
65impl<S, C> RpcClient<S, C>
66where
67    S: Service,
68    C: Connector<S>,
69{
70    /// RPC call to the server, single request, single response
71    pub async fn rpc<M>(&self, msg: M) -> result::Result<M::Response, Error<C>>
72    where
73        M: RpcMsg<S>,
74    {
75        let msg = msg.into();
76        let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?;
77        send.send(msg).await.map_err(Error::<C>::Send)?;
78        let res = recv
79            .next()
80            .await
81            .ok_or(Error::<C>::EarlyClose)?
82            .map_err(Error::<C>::RecvError)?;
83        // keep send alive until we have the answer
84        drop(send);
85        M::Response::try_from(res).map_err(|_| Error::DowncastError)
86    }
87}
88
89impl<S, C> RpcChannel<S, C>
90where
91    S: Service,
92    C: StreamTypes<In = S::Req, Out = S::Res>,
93{
94    /// handle the message of type `M` using the given function on the target object
95    ///
96    /// If you want to support concurrent requests, you need to spawn this on a tokio task yourself.
97    pub async fn rpc<M, F, Fut, T>(
98        self,
99        req: M,
100        target: T,
101        f: F,
102    ) -> result::Result<(), RpcServerError<C>>
103    where
104        M: RpcMsg<S>,
105        F: FnOnce(T, M) -> Fut,
106        Fut: Future<Output = M::Response>,
107        T: Send + 'static,
108    {
109        let Self {
110            mut send, mut recv, ..
111        } = self;
112        // cancel if we get an update, no matter what it is
113        let cancel = recv
114            .next()
115            .map(|_| RpcServerError::UnexpectedUpdateMessage::<C>);
116        // race the computation and the cancellation
117        race2(cancel.map(Err), async move {
118            // get the response
119            let res = f(target, req).await;
120            // turn into a S::Res so we can send it
121            let res = res.into();
122            // send it and return the error if any
123            send.send(res).await.map_err(RpcServerError::SendError)
124        })
125        .await
126    }
127
128    /// A rpc call that also maps the error from the user type to the wire type
129    ///
130    /// This is useful if you want to write your function with a convenient error type like anyhow::Error,
131    /// yet still use a serializable error type on the wire.
132    pub async fn rpc_map_err<M, F, Fut, T, R, E1, E2>(
133        self,
134        req: M,
135        target: T,
136        f: F,
137    ) -> result::Result<(), RpcServerError<C>>
138    where
139        M: RpcMsg<S, Response = result::Result<R, E2>>,
140        F: FnOnce(T, M) -> Fut,
141        Fut: Future<Output = result::Result<R, E1>>,
142        E2: From<E1>,
143        T: Send + 'static,
144    {
145        let fut = |target: T, msg: M| async move {
146            // call the inner fn
147            let res: Result<R, E1> = f(target, msg).await;
148            // convert the error type
149            let res: Result<R, E2> = res.map_err(E2::from);
150            res
151        };
152        self.rpc(req, target, fut).await
153    }
154}