1use std::{
4 error,
5 fmt::{self, Debug},
6 result,
7};
8
9use futures_lite::{Future, StreamExt};
10use futures_util::{FutureExt, SinkExt};
11
12use crate::{
13 message::{InteractionPattern, Msg},
14 server::{race2, RpcChannel, RpcServerError},
15 transport::{ConnectionErrors, StreamTypes},
16 Connector, RpcClient, Service,
17};
18
/// Marker type for the plain request/response interaction pattern: one
/// request message is answered by one response message.
#[derive(Debug, Clone, Copy)]
pub struct Rpc;
impl InteractionPattern for Rpc {}
25
/// A message that follows the [`Rpc`] interaction pattern: it is answered by
/// a single value of type [`RpcMsg::Response`].
pub trait RpcMsg<S: Service>: Msg<S, Pattern = Rpc> {
    /// The reply to this message.
    ///
    /// It must convert into the service-wide response type `S::Res` (the
    /// server side does this with `into()` before sending). It must also be
    /// recoverable from `S::Res` via `TryFrom` (the client side does this
    /// after receiving).
    type Response: Into<S::Res> + TryFrom<S::Res> + Send + 'static;
}
37
38impl<T: RpcMsg<S>, S: Service> Msg<S> for T {
40 type Pattern = Rpc;
41}
/// Error type of `RpcClient::rpc`.
///
/// `C` supplies the transport-specific error types that are wrapped here.
#[derive(Debug)]
pub enum Error<C: ConnectionErrors> {
    /// Opening the send/receive pair for the request failed.
    Open(C::OpenError),
    /// Sending the request failed.
    Send(C::SendError),
    /// The receive side ended before any response arrived.
    EarlyClose,
    /// Receiving the response failed.
    RecvError(C::RecvError),
    /// The received response could not be converted into the message's
    /// `Response` type. The conversion error itself is not kept.
    DowncastError,
}
56
57impl<C: ConnectionErrors> fmt::Display for Error<C> {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 fmt::Debug::fmt(self, f)
60 }
61}
62
63impl<C: ConnectionErrors> error::Error for Error<C> {}
64
65impl<S, C> RpcClient<S, C>
66where
67 S: Service,
68 C: Connector<S>,
69{
70 pub async fn rpc<M>(&self, msg: M) -> result::Result<M::Response, Error<C>>
72 where
73 M: RpcMsg<S>,
74 {
75 let msg = msg.into();
76 let (mut send, mut recv) = self.source.open().await.map_err(Error::Open)?;
77 send.send(msg).await.map_err(Error::<C>::Send)?;
78 let res = recv
79 .next()
80 .await
81 .ok_or(Error::<C>::EarlyClose)?
82 .map_err(Error::<C>::RecvError)?;
83 drop(send);
85 M::Response::try_from(res).map_err(|_| Error::DowncastError)
86 }
87}
88
89impl<S, C> RpcChannel<S, C>
90where
91 S: Service,
92 C: StreamTypes<In = S::Req, Out = S::Res>,
93{
94 pub async fn rpc<M, F, Fut, T>(
98 self,
99 req: M,
100 target: T,
101 f: F,
102 ) -> result::Result<(), RpcServerError<C>>
103 where
104 M: RpcMsg<S>,
105 F: FnOnce(T, M) -> Fut,
106 Fut: Future<Output = M::Response>,
107 T: Send + 'static,
108 {
109 let Self {
110 mut send, mut recv, ..
111 } = self;
112 let cancel = recv
114 .next()
115 .map(|_| RpcServerError::UnexpectedUpdateMessage::<C>);
116 race2(cancel.map(Err), async move {
118 let res = f(target, req).await;
120 let res = res.into();
122 send.send(res).await.map_err(RpcServerError::SendError)
124 })
125 .await
126 }
127
128 pub async fn rpc_map_err<M, F, Fut, T, R, E1, E2>(
133 self,
134 req: M,
135 target: T,
136 f: F,
137 ) -> result::Result<(), RpcServerError<C>>
138 where
139 M: RpcMsg<S, Response = result::Result<R, E2>>,
140 F: FnOnce(T, M) -> Fut,
141 Fut: Future<Output = result::Result<R, E1>>,
142 E2: From<E1>,
143 T: Send + 'static,
144 {
145 let fut = |target: T, msg: M| async move {
146 let res: Result<R, E1> = f(target, msg).await;
148 let res: Result<R, E2> = res.map_err(E2::from);
150 res
151 };
152 self.rpc(req, target, fut).await
153 }
154}