// quic_rpc/pattern/rpc.rs
//! RPC interaction pattern.

use std::{
    error,
    fmt::{self, Debug},
    result,
};

use futures_lite::{Future, StreamExt};
use futures_util::{FutureExt, SinkExt};

use crate::{
    message::{InteractionPattern, Msg},
    server::{race2, RpcChannel, RpcServerError},
    transport::{ConnectionErrors, StreamTypes},
    Connector, RpcClient, Service,
};

/// Marker type for the rpc interaction pattern.
///
/// The client sends exactly one request and receives exactly one response.
#[derive(Clone, Copy, Debug)]
pub struct Rpc;

impl InteractionPattern for Rpc {}

/// A message that follows the [Rpc] interaction pattern: one request, one response.
///
/// The only item of this trait is the response type. Since this is the most common
/// interaction pattern, a blanket impl in this module also implements [Msg] for every
/// `RpcMsg` automatically, with the interaction pattern set to [Rpc]. This is to reduce
/// boilerplate when defining rpc messages.
pub trait RpcMsg<S: Service>: Msg<S, Pattern = Rpc> {
    /// The type for the response.
    ///
    /// On the server it is converted into `S::Res` before being sent. On the client it is
    /// recovered from the received `S::Res` via `TryFrom`.
    ///
    /// For requests that can produce errors, this can be set to [Result<T, E>](std::result::Result).
    type Response: Into<S::Res> + TryFrom<S::Res> + Send + 'static;
}

/// Every [`RpcMsg`] is automatically a [`Msg`] whose pattern is [`Rpc`].
///
/// Only one such blanket impl can exist for [`Msg`]. It is provided for [`RpcMsg`] because
/// rpc is the most common interaction pattern.
impl<T, S> Msg<S> for T
where
    S: Service,
    T: RpcMsg<S>,
{
    type Pattern = Rpc;
}

/// Client-side error for the rpc pattern, returned by [`RpcClient::rpc`].
///
/// Each variant corresponds to the stage of the call that failed: opening the substream,
/// sending the request, receiving the response, or converting it.
#[derive(Debug)]
pub enum Error<C: ConnectionErrors> {
    /// Unable to open a substream at all (the connector's `open` failed)
    Open(C::OpenError),
    /// Unable to send the request to the server
    Send(C::SendError),
    /// Server closed the stream before sending a response (the receive stream ended without
    /// yielding any item)
    EarlyClose,
    /// Unable to receive the response from the server
    RecvError(C::RecvError),
    /// Unexpected response from the server: it could not be converted into the message's
    /// `Response` type
    DowncastError,
}

impl<C: ConnectionErrors> fmt::Display for Error<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}

impl<C: ConnectionErrors> error::Error for Error<C> {}

impl<S, C> RpcClient<S, C>
where
    S: Service,
    C: Connector<S>,
{
    /// RPC call to the server, single request, single response
    pub async fn rpc<M>(&self, msg: M) -> result::Result<M::Response, Error<C>>
    where
        M: RpcMsg<S>,
    {
        let msg = msg.into();
        let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?;
        send.send(msg).await.map_err(Error::<C>::Send)?;
        let res = recv
            .next()
            .await
            .ok_or(Error::<C>::EarlyClose)?
            .map_err(Error::<C>::RecvError)?;
        // keep send alive until we have the answer
        drop(send);
        M::Response::try_from(res).map_err(|_| Error::DowncastError)
    }
}

impl<S, C> RpcChannel<S, C>
where
    S: Service,
    C: StreamTypes<In = S::Req, Out = S::Res>,
{
    /// Handle the message of type `M` using the given function on the target object.
    ///
    /// Calls `f(target, req)`, converts the returned `M::Response` into `S::Res`, and sends
    /// it back on the channel. The computation is raced (via `race2`) against the next
    /// item of the channel's receive stream. Anything arriving there counts as a
    /// cancellation: an update message, a receive error, or the end of the stream. Clients
    /// therefore have to keep their send half open until the response arrives.
    ///
    /// This handles a single request inline. If you want to support concurrent requests,
    /// you need to spawn this on a tokio task yourself.
    ///
    /// # Errors
    ///
    /// * [`RpcServerError::UnexpectedUpdateMessage`] if the receive stream yields anything
    ///   first.
    /// * [`RpcServerError::SendError`] if sending the response fails.
    pub async fn rpc<M, F, Fut, T>(
        self,
        req: M,
        target: T,
        f: F,
    ) -> result::Result<(), RpcServerError<C>>
    where
        M: RpcMsg<S>,
        F: FnOnce(T, M) -> Fut,
        Fut: Future<Output = M::Response>,
        T: Send + 'static,
    {
        let Self {
            mut send, mut recv, ..
        } = self;
        // cancel if we get an update, no matter what it is (the closure ignores the item,
        // so `None` - the stream ending - also cancels)
        let cancel = recv
            .next()
            .map(|_| RpcServerError::UnexpectedUpdateMessage::<C>);
        // race the computation and the cancellation
        race2(cancel.map(Err), async move {
            // get the response
            let res = f(target, req).await;
            // turn into a S::Res so we can send it
            let res = res.into();
            // send it and return the error if any
            send.send(res).await.map_err(RpcServerError::SendError)
        })
        .await
    }

    /// A rpc call that also maps the error from the user type to the wire type.
    ///
    /// This is useful if you want to write your function with a convenient error type like
    /// anyhow::Error, yet still use a serializable error type on the wire. The `Err` value
    /// produced by `f` (type `E1`) is converted into `E2` via `From`, and then handled
    /// exactly like [`RpcChannel::rpc`], including its cancellation behavior and errors.
    pub async fn rpc_map_err<M, F, Fut, T, R, E1, E2>(
        self,
        req: M,
        target: T,
        f: F,
    ) -> result::Result<(), RpcServerError<C>>
    where
        M: RpcMsg<S, Response = result::Result<R, E2>>,
        F: FnOnce(T, M) -> Fut,
        Fut: Future<Output = result::Result<R, E1>>,
        E2: From<E1>,
        T: Send + 'static,
    {
        // Wrap `f` so its `Result<R, E1>` becomes the wire-level `Result<R, E2>`.
        let convert_err =
            |target: T, msg: M| async move { f(target, msg).await.map_err(E2::from) };
        self.rpc(req, target, convert_err).await
    }
}