quic_rpc/pattern/
bidi_streaming.rs1use std::{
4 error,
5 fmt::{self, Debug},
6 result,
7};
8
9use futures_lite::{Stream, StreamExt};
10use futures_util::{FutureExt, SinkExt};
11
12use crate::{
13 client::{BoxStreamSync, UpdateSink},
14 message::{InteractionPattern, Msg},
15 server::{race2, RpcChannel, RpcServerError, UpdateStream},
16 transport::{ConnectionErrors, Connector, StreamTypes},
17 RpcClient, Service,
18};
19
20#[derive(Debug, Clone, Copy)]
25pub struct BidiStreaming;
26impl InteractionPattern for BidiStreaming {}
27
28pub trait BidiStreamingMsg<S: Service>: Msg<S, Pattern = BidiStreaming> {
30 type Update: Into<S::Req> + TryFrom<S::Req> + Send + 'static;
35
36 type Response: Into<S::Res> + TryFrom<S::Res> + Send + 'static;
40}
41
42#[derive(Debug)]
44pub enum Error<C: ConnectionErrors> {
45 Open(C::OpenError),
47 Send(C::SendError),
49}
50
51impl<C: ConnectionErrors> fmt::Display for Error<C> {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 fmt::Debug::fmt(self, f)
54 }
55}
56
57impl<C: ConnectionErrors> error::Error for Error<C> {}
58
59#[derive(Debug)]
61pub enum ItemError<C: ConnectionErrors> {
62 RecvError(C::RecvError),
64 DowncastError,
66}
67
68impl<C: ConnectionErrors> fmt::Display for ItemError<C> {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 fmt::Debug::fmt(self, f)
71 }
72}
73
74impl<C: ConnectionErrors> error::Error for ItemError<C> {}
75
76impl<S, C> RpcClient<S, C>
77where
78 S: Service,
79 C: Connector<In = S::Res, Out = S::Req>,
80{
81 pub async fn bidi<M>(
83 &self,
84 msg: M,
85 ) -> result::Result<
86 (
87 UpdateSink<C, M::Update>,
88 BoxStreamSync<'static, result::Result<M::Response, ItemError<C>>>,
89 ),
90 Error<C>,
91 >
92 where
93 M: BidiStreamingMsg<S>,
94 {
95 let msg = msg.into();
96 let (mut send, recv) = self.source.open().await.map_err(Error::Open)?;
97 send.send(msg).await.map_err(Error::<C>::Send)?;
98 let send = UpdateSink::new(send);
99 let recv = Box::pin(recv.map(move |x| match x {
100 Ok(msg) => M::Response::try_from(msg).map_err(|_| ItemError::DowncastError),
101 Err(e) => Err(ItemError::RecvError(e)),
102 }));
103 Ok((send, recv))
104 }
105}
106
107impl<C, S> RpcChannel<S, C>
108where
109 C: StreamTypes<In = S::Req, Out = S::Res>,
110 S: Service,
111{
112 pub async fn bidi_streaming<M, F, Str, T>(
116 self,
117 req: M,
118 target: T,
119 f: F,
120 ) -> result::Result<(), RpcServerError<C>>
121 where
122 M: BidiStreamingMsg<S>,
123 F: FnOnce(T, M, UpdateStream<C, M::Update>) -> Str + Send + 'static,
124 Str: Stream<Item = M::Response> + Send + 'static,
125 T: Send + 'static,
126 {
127 let Self { mut send, recv, .. } = self;
128 let (updates, read_error) = UpdateStream::new(recv);
130 let responses = f(target, req, updates);
132 race2(read_error.map(Err), async move {
133 tokio::pin!(responses);
134 while let Some(response) = responses.next().await {
135 let response = response.into();
137 send.send(response)
139 .await
140 .map_err(RpcServerError::SendError)?;
141 }
142 Ok(())
143 })
144 .await
145 }
146}