quic_rpc/pattern/rpc.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
//! RPC interaction pattern.
use std::{
error,
fmt::{self, Debug},
result,
};
use futures_lite::{Future, StreamExt};
use futures_util::{FutureExt, SinkExt};
use crate::{
message::{InteractionPattern, Msg},
server::{race2, RpcChannel, RpcServerError},
transport::{ConnectionErrors, StreamTypes},
Connector, RpcClient, Service,
};
/// Marker type for the rpc interaction pattern.
///
/// In this pattern exactly one request is answered by exactly one response.
#[derive(Clone, Copy, Debug)]
pub struct Rpc;
/// Declares [`Rpc`] to be an [`InteractionPattern`].
///
/// The impl body is empty; presumably `InteractionPattern` is a marker trait
/// (or all of its items have defaults) — confirm in `crate::message`.
impl InteractionPattern for Rpc {}
/// Declares the response type of an rpc message.
///
/// Rpc is the most common interaction pattern, so implementing this trait is all a
/// message type needs: a blanket impl in this module supplies [`Msg`] with the
/// interaction pattern fixed to [`Rpc`], which saves boilerplate when defining rpc
/// messages.
pub trait RpcMsg<S: Service>: Msg<S, Pattern = Rpc> {
    /// Type of the single response sent back for this request.
    ///
    /// It must be convertible into the service-wide response type `S::Res` (for
    /// sending) and fallibly convertible back from it (for receiving). Requests that
    /// can fail may use a [`Result<T, E>`](std::result::Result) here.
    type Response: Send + 'static + Into<S::Res> + TryFrom<S::Res>;
}
/// Blanket impl: every [`RpcMsg`] is automatically a [`Msg`] whose pattern is [`Rpc`].
///
/// Only one trait can be given a blanket impl like this, so it is reserved for
/// [`RpcMsg`], the most common pattern.
impl<T, S> Msg<S> for T
where
    T: RpcMsg<S>,
    S: Service,
{
    type Pattern = Rpc;
}
/// Error type of the client side of the rpc pattern.
///
/// The client DSL methods report failures with this type; each variant marks the
/// stage of the request/response exchange that went wrong.
#[derive(Debug)]
pub enum Error<C>
where
    C: ConnectionErrors,
{
    /// Opening the substream failed.
    Open(C::OpenError),
    /// Writing the request to the server failed.
    Send(C::SendError),
    /// The server's side of the stream ended before any response arrived.
    EarlyClose,
    /// Reading the response from the server failed.
    RecvError(C::RecvError),
    /// The response could not be converted into the expected response type.
    DowncastError,
}
impl<C: ConnectionErrors> fmt::Display for Error<C> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}
impl<C: ConnectionErrors> error::Error for Error<C> {}
impl<S, C> RpcClient<S, C>
where
S: Service,
C: Connector<S>,
{
/// RPC call to the server, single request, single response
pub async fn rpc<M>(&self, msg: M) -> result::Result<M::Response, Error<C>>
where
M: RpcMsg<S>,
{
let msg = msg.into();
let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?;
send.send(msg).await.map_err(Error::<C>::Send)?;
let res = recv
.next()
.await
.ok_or(Error::<C>::EarlyClose)?
.map_err(Error::<C>::RecvError)?;
// keep send alive until we have the answer
drop(send);
M::Response::try_from(res).map_err(|_| Error::DowncastError)
}
}
impl<S, C> RpcChannel<S, C>
where
    S: Service,
    C: StreamTypes<In = S::Req, Out = S::Res>,
{
    /// Handles a request of type `M` by calling `f(target, req)` and sending back its result.
    ///
    /// The handler is raced against the channel's receive side. Any item arriving
    /// there is treated as an unexpected update, and so is the receive stream ending,
    /// because its value is ignored. In that case the call resolves to
    /// [`RpcServerError::UnexpectedUpdateMessage`].
    ///
    /// This method handles one request at a time. To serve requests concurrently,
    /// spawn it on a tokio task yourself.
    ///
    /// # Errors
    ///
    /// - `RpcServerError::UnexpectedUpdateMessage` when the receive side produces
    ///   anything; presumably this applies only if that happens before the handler's
    ///   future finishes (depends on `race2` — confirm).
    /// - `RpcServerError::SendError` when sending the response fails.
    pub async fn rpc<M, F, Fut, T>(
        self,
        req: M,
        target: T,
        f: F,
    ) -> result::Result<(), RpcServerError<C>>
    where
        M: RpcMsg<S>,
        F: FnOnce(T, M) -> Fut,
        Fut: Future<Output = M::Response>,
        T: Send + 'static,
    {
        let Self { send: mut tx, recv: mut rx, .. } = self;
        // Whatever the receive side produces next, even the end of the stream, turns
        // into an error. This stays a named local so its borrow of `rx` ends before
        // `rx` itself is dropped.
        let cancel = rx
            .next()
            .map(|_| Err(RpcServerError::<C>::UnexpectedUpdateMessage));
        // Run the user handler, convert its answer to `S::Res`, and ship it back.
        let work = async move {
            let wire = f(target, req).await.into();
            tx.send(wire).await.map_err(RpcServerError::<C>::SendError)
        };
        // `race2` (crate::server) presumably resolves with the first future to finish.
        race2(cancel, work).await
    }

    /// Like [`RpcChannel::rpc`], but lets the handler use its own error type `E1`.
    ///
    /// Before the response is sent, the handler's error is converted into the wire
    /// error type `E2` via `From`. This allows handlers with a convenient error type,
    /// such as `anyhow::Error`, while a serializable error type goes on the wire.
    pub async fn rpc_map_err<M, F, Fut, T, R, E1, E2>(
        self,
        req: M,
        target: T,
        f: F,
    ) -> result::Result<(), RpcServerError<C>>
    where
        M: RpcMsg<S, Response = result::Result<R, E2>>,
        F: FnOnce(T, M) -> Fut,
        Fut: Future<Output = result::Result<R, E1>>,
        E2: From<E1>,
        T: Send + 'static,
    {
        // Wrap `f` so the future it yields already carries the wire error type.
        let adapted = move |target: T, msg: M| async move {
            let outcome: result::Result<R, E1> = f(target, msg).await;
            outcome.map_err(E2::from)
        };
        self.rpc(req, target, adapted).await
    }
}