1use std::{
5 error,
6 fmt::{self, Debug},
7 marker::PhantomData,
8 pin::Pin,
9 result,
10 sync::Arc,
11 task::{self, Poll},
12};
13
14use futures_lite::{Future, Stream, StreamExt};
15use futures_util::{SinkExt, TryStreamExt};
16use pin_project::pin_project;
17use tokio::{sync::oneshot, task::JoinSet};
18use tokio_util::task::AbortOnDropHandle;
19use tracing::{error, warn};
20
21use crate::{
22 transport::{
23 self,
24 boxed::BoxableListener,
25 mapped::{ErrorOrMapError, MappedRecvStream, MappedSendSink, MappedStreamTypes},
26 ConnectionErrors, StreamTypes,
27 },
28 Listener, RpcMessage, Service,
29};
30
31pub trait ChannelTypes<S: Service>: transport::StreamTypes<In = S::Req, Out = S::Res> {}
36
37impl<T: transport::StreamTypes<In = S::Req, Out = S::Res>, S: Service> ChannelTypes<S> for T {}
38
/// Type-erased channel types for service `S`, parameterized by the service's
/// request (`S::Req`) and response (`S::Res`) types.
pub type BoxedChannelTypes<S> = crate::transport::boxed::BoxedStreamTypes<
    <S as crate::Service>::Req,
    <S as crate::Service>::Res,
>;

/// Type-erased listener for service `S`; this is the default listener type of
/// [`RpcServer`].
pub type BoxedListener<S> =
    crate::transport::boxed::BoxedListener<<S as crate::Service>::Req, <S as crate::Service>::Res>;
48
#[cfg(feature = "flume-transport")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "flume-transport")))]
/// Flume-transport listener specialized to the request/response types of
/// service `S` (requires the `flume-transport` feature).
pub type FlumeListener<S> =
    crate::transport::flume::FlumeListener<<S as Service>::Req, <S as Service>::Res>;

#[cfg(feature = "quinn-transport")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "quinn-transport")))]
/// Quinn-transport listener specialized to the request/response types of
/// service `S` (requires the `quinn-transport` feature).
pub type QuinnListener<S> =
    crate::transport::quinn::QuinnListener<<S as Service>::Req, <S as Service>::Res>;

#[cfg(feature = "hyper-transport")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "hyper-transport")))]
/// Hyper-transport listener specialized to the request/response types of
/// service `S` (requires the `hyper-transport` feature).
pub type HyperListener<S> =
    crate::transport::hyper::HyperListener<<S as Service>::Req, <S as Service>::Res>;

#[cfg(feature = "iroh-transport")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "iroh-transport")))]
/// Iroh-transport listener specialized to the request/response types of
/// service `S` (requires the `iroh-transport` feature).
pub type IrohListener<S> =
    crate::transport::iroh::IrohListener<<S as Service>::Req, <S as Service>::Res>;
72
#[derive(Debug)]
/// A server for service `S` that accepts connections from the listener `C`.
///
/// `C` defaults to the type-erased [`BoxedListener`] for `S`. Use
/// [`RpcServer::accept`] for one connection at a time, or
/// [`RpcServer::accept_loop`] to handle connections concurrently.
pub struct RpcServer<S, C = BoxedListener<S>> {
    // The listener new connections are accepted from.
    source: C,
    // Ties the server to its service type without storing a value of it.
    _p: PhantomData<S>,
}
90
91impl<S, C: Clone> Clone for RpcServer<S, C> {
92 fn clone(&self) -> Self {
93 Self {
94 source: self.source.clone(),
95 _p: PhantomData,
96 }
97 }
98}
99
100impl<S: Service, C: Listener<S>> RpcServer<S, C> {
101 pub fn new(source: C) -> Self {
106 Self {
107 source,
108 _p: PhantomData,
109 }
110 }
111
112 pub fn boxed(self) -> RpcServer<S, BoxedListener<S>>
117 where
118 C: BoxableListener<S::Req, S::Res>,
119 {
120 RpcServer::new(self.source.boxed())
121 }
122}
123
#[derive(Debug)]
/// The two halves of a single server-side connection for service `S`.
///
/// `C` fixes `In = S::Req` and `Out = S::Res` (see [`ChannelTypes`]). Requests
/// are read from `recv` (as [`Accepting::read_first`] does). Presumably responses
/// are written to `send`; verify against the transport's `SendSink`.
pub struct RpcChannel<S: Service, C: ChannelTypes<S> = BoxedChannelTypes<S>> {
    /// Sink half of the connection.
    pub send: C::SendSink,
    /// Stream half of the connection; yields `Result<S::Req, C::RecvError>` items.
    pub recv: C::RecvStream,

    pub(crate) _p: PhantomData<S>,
}
145
146impl<S, C> RpcChannel<S, C>
147where
148 S: Service,
149 C: StreamTypes<In = S::Req, Out = S::Res>,
150{
151 pub fn new(send: C::SendSink, recv: C::RecvStream) -> Self {
153 Self {
154 send,
155 recv,
156 _p: PhantomData,
157 }
158 }
159
160 pub fn boxed(self) -> RpcChannel<S, BoxedChannelTypes<S>>
162 where
163 C::SendError: Into<anyhow::Error> + Send + Sync + 'static,
164 C::RecvError: Into<anyhow::Error> + Send + Sync + 'static,
165 {
166 let send =
167 transport::boxed::SendSink::boxed(Box::new(self.send.sink_map_err(|e| e.into())));
168 let recv = transport::boxed::RecvStream::boxed(Box::new(self.recv.map_err(|e| e.into())));
169 RpcChannel::new(send, recv)
170 }
171
172 pub fn map<SNext>(self) -> RpcChannel<SNext, MappedStreamTypes<SNext::Req, SNext::Res, C>>
182 where
183 SNext: Service,
184 SNext::Req: TryFrom<S::Req>,
185 S::Res: From<SNext::Res>,
186 {
187 RpcChannel::new(
188 MappedSendSink::new(self.send),
189 MappedRecvStream::new(self.recv),
190 )
191 }
192}
193
/// A connection returned by [`RpcServer::accept`] whose first message has not
/// been read yet; call [`Accepting::read_first`] to continue.
pub struct Accepting<S: Service, C: Listener<S>> {
    // Sink half of the accepted connection.
    send: C::SendSink,
    // Stream half of the accepted connection; the first request is read from here.
    recv: C::RecvStream,
    _p: PhantomData<S>,
}
200
201impl<S: Service, C: Listener<S>> Accepting<S, C> {
202 pub async fn read_first(self) -> result::Result<(S::Req, RpcChannel<S, C>), RpcServerError<C>> {
212 let Accepting { send, mut recv, .. } = self;
213 let request: S::Req = recv
215 .next()
216 .await
217 .ok_or(RpcServerError::EarlyClose)?
219 .map_err(RpcServerError::RecvError)?;
221 Ok((request, RpcChannel::<S, C>::new(send, recv)))
222 }
223}
224
225impl<S: Service, C: Listener<S>> RpcServer<S, C> {
226 pub async fn accept(&self) -> result::Result<Accepting<S, C>, RpcServerError<C>> {
229 let (send, recv) = self.source.accept().await.map_err(RpcServerError::Accept)?;
230 Ok(Accepting {
231 send,
232 recv,
233 _p: PhantomData,
234 })
235 }
236
237 pub fn into_inner(self) -> C {
239 self.source
240 }
241
242 pub async fn accept_loop<Fun, Fut, E>(self, handler: Fun)
248 where
249 S: Service,
250 C: Listener<S>,
251 Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
252 Fut: Future<Output = Result<(), E>> + Send + 'static,
253 E: Into<anyhow::Error> + 'static,
254 {
255 let handler = Arc::new(handler);
256 let mut tasks = JoinSet::new();
257 loop {
258 tokio::select! {
259 Some(res) = tasks.join_next(), if !tasks.is_empty() => {
260 if let Err(e) = res {
261 if e.is_panic() {
262 error!("Panic handling RPC request: {e}");
263 }
264 }
265 }
266 req = self.accept() => {
267 let req = match req {
268 Ok(req) => req,
269 Err(e) => {
270 warn!("Error accepting RPC request: {e}");
271 continue;
272 }
273 };
274 let handler = handler.clone();
275 tasks.spawn(async move {
276 let (req, chan) = match req.read_first().await {
277 Ok((req, chan)) => (req, chan),
278 Err(e) => {
279 warn!("Error reading first message: {e}");
280 return;
281 }
282 };
283 if let Err(cause) = handler(req, chan).await {
284 warn!("Error handling RPC request: {}", cause.into());
285 }
286 });
287 }
288 }
289 }
290 }
291
292 pub fn spawn_accept_loop<Fun, Fut, E>(self, handler: Fun) -> AbortOnDropHandle<()>
294 where
295 S: Service,
296 C: Listener<S>,
297 Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
298 Fut: Future<Output = Result<(), E>> + Send + 'static,
299 E: Into<anyhow::Error> + 'static,
300 {
301 AbortOnDropHandle::new(tokio::spawn(self.accept_loop(handler)))
302 }
303}
304
305impl<S: Service, C: Listener<S>> AsRef<C> for RpcServer<S, C> {
306 fn as_ref(&self) -> &C {
307 &self.source
308 }
309}
310
#[pin_project]
/// A stream of messages of type `T`, each converted with `TryFrom` from the
/// items of the underlying `C::RecvStream`.
///
/// If a message fails to convert, or the underlying stream yields an error,
/// the error is reported once through the contained oneshot sender. The stream
/// then returns `Poll::Pending` instead of an item. The matching receiver
/// is handed out by `UpdateStream::new`.
#[derive(Debug)]
pub struct UpdateStream<C, T>(
    // The underlying receive stream.
    #[pin] C::RecvStream,
    // Reports the first error; taken (set to `None`) once it has been used.
    Option<oneshot::Sender<RpcServerError<C>>>,
    // Marks the item type `T` without storing one.
    PhantomData<T>,
)
where
    C: StreamTypes;
324
325impl<C, T> UpdateStream<C, T>
326where
327 C: StreamTypes,
328 T: TryFrom<C::In>,
329{
330 pub(crate) fn new(recv: C::RecvStream) -> (Self, UnwrapToPending<RpcServerError<C>>) {
331 let (error_send, error_recv) = oneshot::channel();
332 let error_recv = UnwrapToPending(futures_lite::future::fuse(error_recv));
333 (Self(recv, Some(error_send), PhantomData), error_recv)
334 }
335}
336
337impl<C, T> Stream for UpdateStream<C, T>
338where
339 C: StreamTypes,
340 T: TryFrom<C::In>,
341{
342 type Item = T;
343
344 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
345 let mut this = self.project();
346 match Pin::new(&mut this.0).poll_next(cx) {
347 Poll::Ready(Some(msg)) => match msg {
348 Ok(msg) => {
349 let msg = T::try_from(msg).map_err(|_cause| ());
350 match msg {
351 Ok(msg) => Poll::Ready(Some(msg)),
352 Err(_cause) => {
353 if let Some(tx) = this.1.take() {
355 let _ = tx.send(RpcServerError::UnexpectedUpdateMessage);
356 }
357 Poll::Pending
358 }
359 }
360 }
361 Err(cause) => {
362 if let Some(tx) = this.1.take() {
364 let _ = tx.send(RpcServerError::RecvError(cause));
365 }
366 Poll::Pending
367 }
368 },
369 Poll::Ready(None) => Poll::Ready(None),
370 Poll::Pending => Poll::Pending,
371 }
372 }
373}
374
/// Errors produced on the server side of an RPC connection, generic over the
/// transport's error types `C`.
pub enum RpcServerError<C: ConnectionErrors> {
    /// The listener failed to accept a connection.
    Accept(C::AcceptError),
    /// The connection ended before its first message was received.
    EarlyClose,
    /// The first message was not of the expected kind (not constructed in this
    /// module; presumably raised by request handlers — verify at call sites).
    UnexpectedStartMessage,
    /// Receiving a message failed.
    RecvError(C::RecvError),
    /// Sending a message failed (wraps the transport's `SendError`).
    SendError(C::SendError),
    /// An update message could not be converted to the expected type.
    UnexpectedUpdateMessage,
}
390
391impl<In: RpcMessage, Out: RpcMessage, C: ConnectionErrors>
392 RpcServerError<MappedStreamTypes<In, Out, C>>
393{
394 pub fn map_back(self) -> RpcServerError<C> {
396 match self {
397 RpcServerError::EarlyClose => RpcServerError::EarlyClose,
398 RpcServerError::UnexpectedStartMessage => RpcServerError::UnexpectedStartMessage,
399 RpcServerError::UnexpectedUpdateMessage => RpcServerError::UnexpectedUpdateMessage,
400 RpcServerError::SendError(x) => RpcServerError::SendError(x),
401 RpcServerError::Accept(x) => RpcServerError::Accept(x),
402 RpcServerError::RecvError(ErrorOrMapError::Inner(x)) => RpcServerError::RecvError(x),
403 RpcServerError::RecvError(ErrorOrMapError::Conversion) => {
404 RpcServerError::UnexpectedUpdateMessage
405 }
406 }
407 }
408}
409
410impl<C: ConnectionErrors> RpcServerError<C> {
411 pub fn errors_into<T>(self) -> RpcServerError<T>
413 where
414 T: ConnectionErrors,
415 C::SendError: Into<T::SendError>,
416 C::RecvError: Into<T::RecvError>,
417 C::AcceptError: Into<T::AcceptError>,
418 {
419 match self {
420 RpcServerError::EarlyClose => RpcServerError::EarlyClose,
421 RpcServerError::UnexpectedStartMessage => RpcServerError::UnexpectedStartMessage,
422 RpcServerError::UnexpectedUpdateMessage => RpcServerError::UnexpectedUpdateMessage,
423 RpcServerError::SendError(x) => RpcServerError::SendError(x.into()),
424 RpcServerError::Accept(x) => RpcServerError::Accept(x.into()),
425 RpcServerError::RecvError(x) => RpcServerError::RecvError(x.into()),
426 }
427 }
428}
429
430impl<C: ConnectionErrors> fmt::Debug for RpcServerError<C> {
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 match self {
433 Self::Accept(arg0) => f.debug_tuple("Open").field(arg0).finish(),
434 Self::EarlyClose => write!(f, "EarlyClose"),
435 Self::RecvError(arg0) => f.debug_tuple("RecvError").field(arg0).finish(),
436 Self::SendError(arg0) => f.debug_tuple("SendError").field(arg0).finish(),
437 Self::UnexpectedStartMessage => f.debug_tuple("UnexpectedStartMessage").finish(),
438 Self::UnexpectedUpdateMessage => f.debug_tuple("UnexpectedStartMessage").finish(),
439 }
440 }
441}
442
443impl<C: ConnectionErrors> fmt::Display for RpcServerError<C> {
444 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
445 fmt::Debug::fmt(&self, f)
446 }
447}
448
449impl<C: ConnectionErrors> error::Error for RpcServerError<C> {}
450
451pub(crate) struct UnwrapToPending<T>(futures_lite::future::Fuse<oneshot::Receiver<T>>);
453
454impl<T> Future for UnwrapToPending<T> {
455 type Output = T;
456
457 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
458 match Pin::new(&mut self.0).poll(cx) {
460 Poll::Ready(Ok(x)) => Poll::Ready(x),
461 Poll::Ready(Err(_)) => Poll::Pending,
462 Poll::Pending => Poll::Pending,
463 }
464 }
465}
466
467pub(crate) async fn race2<T, A: Future<Output = T>, B: Future<Output = T>>(f1: A, f2: B) -> T {
468 tokio::select! {
469 x = f1 => x,
470 x = f2 => x,
471 }
472}
473
474pub async fn run_server_loop<S, C, T, F, Fut>(
478 _service_type: S,
479 conn: C,
480 target: T,
481 mut handler: F,
482) -> Result<(), RpcServerError<C>>
483where
484 S: Service,
485 C: Listener<S>,
486 T: Clone + Send + 'static,
487 F: FnMut(RpcChannel<S, C>, S::Req, T) -> Fut + Send + 'static,
488 Fut: Future<Output = Result<(), RpcServerError<C>>> + Send + 'static,
489{
490 let server: RpcServer<S, C> = RpcServer::<S, C>::new(conn);
491 loop {
492 let (req, chan) = server.accept().await?.read_first().await?;
493 let target = target.clone();
494 handler(chan, req, target).await?;
495 }
496}