1use crate::mediator;
2use crate::wrappers::*;
3use crate::Error;
4use crate::MakeTransport;
5use futures_core::{ready, stream::TryStream};
6use futures_sink::Sink;
7use pin_project::pin_project;
8use std::collections::VecDeque;
9use std::future::Future;
10use std::marker::PhantomData;
11use std::pin::Pin;
12use std::sync::{atomic, Arc};
13use std::task::{Context, Poll};
14use std::{error, fmt};
15use tower_service::Service;
16
/// Pairs outgoing requests with incoming responses by means of a tag.
///
/// The multiplexing worker calls [`TagStore::assign_tag`] on every request before sending it.
/// It calls [`TagStore::finish_tag`] on every response it receives, then looks up the
/// outstanding request whose tag compares equal. Responses may therefore complete requests in
/// any order.
pub trait TagStore<Request, Response> {
    /// Identifier linking a response to its request; matched with `==`.
    type Tag: Eq;

    /// Produces the tag for `request`. The request is passed mutably, so an implementation
    /// may modify it (for example to embed the tag) before it is sent.
    fn assign_tag(self: Pin<&mut Self>, request: &mut Request) -> Self::Tag;

    /// Extracts the tag identifying which request `response` answers.
    fn finish_tag(self: Pin<&mut Self>, response: &Response) -> Self::Tag;
}
36
/// Factory that produces multiplexing [`Client`]s from a transport maker `NT`.
///
/// Its `Service<Target>` implementation asks `NT` for a transport and wraps the result in a
/// new `Client`.
pub struct Maker<NT, Request> {
    t_maker: NT,
    // `fn(Request)` records the request type without storing a value of it.
    _req: PhantomData<fn(Request)>,
}

impl<NT, Request> fmt::Debug for Maker<NT, Request>
where
    NT: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Maker");
        builder.field("t_maker", &self.t_maker);
        builder.finish()
    }
}

impl<NT, Request> Maker<NT, Request> {
    /// Wraps the transport maker `t`.
    pub fn new(t: NT) -> Self {
        Self {
            t_maker: t,
            _req: PhantomData,
        }
    }
}
69
/// Failure reported by [`Maker`] when it cannot produce a client.
#[derive(Debug)]
pub enum SpawnError<E> {
    /// Displayed as an error spawning the multiplex client.
    SpawnFailed,

    /// The transport maker failed; the wrapped value is its error.
    Inner(E),
}
79
80impl<NT, Target, Request> Service<Target> for Maker<NT, Request>
81where
82 NT: MakeTransport<Target, Request>,
83 NT::Transport: 'static + Send + TagStore<Request, NT::Item>,
84 <NT::Transport as TagStore<Request, NT::Item>>::Tag: 'static + Send,
85 Request: 'static + Send,
86 NT::Item: 'static + Send,
87 NT::SinkError: 'static + Send + Sync,
88 NT::Error: 'static + Send + Sync,
89 NT::Future: 'static + Send,
90{
91 type Error = SpawnError<NT::MakeError>;
92 type Response = Client<NT::Transport, Error<NT::Transport, Request>, Request>;
93 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
94
95 fn call(&mut self, target: Target) -> Self::Future {
96 let maker = self.t_maker.make_transport(target);
97 Box::pin(async move { Ok(Client::new(maker.await.map_err(SpawnError::Inner)?)) })
98 }
99
100 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101 self.t_maker.poll_ready(cx).map_err(SpawnError::Inner)
102 }
103}
104
105impl<NT, Request> tower::load::Load for Maker<NT, Request> {
106 type Metric = u8;
107
108 fn load(&self) -> Self::Metric {
109 0
110 }
111}
112
113pub struct Client<T, E, Request>
121where
122 T: Sink<Request> + TryStream,
123{
124 mediator: mediator::Sender<ClientRequest<T, Request>>,
125 in_flight: Arc<atomic::AtomicUsize>,
126 _error: PhantomData<fn(E)>,
127}
128
129impl<T, E, Request> fmt::Debug for Client<T, E, Request>
130where
131 T: Sink<Request> + TryStream,
132{
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("Client")
135 .field("mediator", &self.mediator)
136 .field("in_flight", &self.in_flight)
137 .finish()
138 }
139}
140
141struct Pending<Tag, Item> {
144 tag: Tag,
145 tx: tokio::sync::oneshot::Sender<ClientResponse<Item>>,
146 span: tracing::Span,
147}
148
149#[pin_project]
150struct ClientInner<T, E, Request>
151where
152 T: Sink<Request> + TryStream + TagStore<Request, <T as TryStream>::Ok>,
153{
154 mediator: mediator::Receiver<ClientRequest<T, Request>>,
155 responses: VecDeque<Pending<T::Tag, T::Ok>>,
156 #[pin]
157 transport: T,
158
159 in_flight: Arc<atomic::AtomicUsize>,
160 finish: bool,
161 rx_only: bool,
162
163 #[allow(unused)]
164 error: PhantomData<fn(E)>,
165}
166
167impl<T, E, Request> Client<T, E, Request>
168where
169 T: Sink<Request> + TryStream + TagStore<Request, <T as TryStream>::Ok> + Send + 'static,
170 E: From<Error<T, Request>>,
171 E: 'static + Send,
172 Request: 'static + Send,
173 T::Ok: 'static + Send,
174 T::Tag: Send,
175{
176 pub fn new(transport: T) -> Self where {
181 Self::with_error_handler(transport, |_| {})
182 }
183
184 pub fn with_error_handler<F>(transport: T, on_service_error: F) -> Self
188 where
189 F: FnOnce(E) + Send + 'static,
190 {
191 let (tx, rx) = mediator::new();
192 let in_flight = Arc::new(atomic::AtomicUsize::new(0));
193 tokio::spawn({
194 let c = ClientInner {
195 mediator: rx,
196 responses: Default::default(),
197 transport,
198 in_flight: in_flight.clone(),
199 error: PhantomData::<fn(E)>,
200 finish: false,
201 rx_only: false,
202 };
203 async move {
204 if let Err(e) = c.await {
205 on_service_error(e);
206 }
207 }
208 });
209 Client {
210 mediator: tx,
211 in_flight,
212 _error: PhantomData,
213 }
214 }
215}
216
217impl<T, E, Request> Future for ClientInner<T, E, Request>
218where
219 T: Sink<Request> + TryStream + TagStore<Request, <T as TryStream>::Ok>,
220 E: From<Error<T, Request>>,
221 E: 'static + Send,
222 Request: 'static + Send,
223 T::Ok: 'static + Send,
224{
225 type Output = Result<(), E>;
226
227 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
228 let this = self.project();
230
231 let mut transport: Pin<_> = this.transport;
233
234 let mut i = 0;
236
237 if !*this.finish {
238 while let Poll::Ready(r) = transport.as_mut().poll_ready(cx) {
239 if let Err(e) = r {
240 return Poll::Ready(Err(E::from(Error::from_sink_error(e))));
241 }
242
243 match this.mediator.try_recv(cx) {
245 Poll::Ready(Some(ClientRequest {
246 mut req,
247 span: _span,
248 res,
249 })) => {
250 let id = transport.as_mut().assign_tag(&mut req);
251
252 let guard = _span.enter();
253 tracing::trace!("request received by worker; sending to Sink");
254
255 transport
256 .as_mut()
257 .start_send(req)
258 .map_err(Error::from_sink_error)?;
259 tracing::trace!("request sent");
260 drop(guard);
261
262 this.responses.push_back(Pending {
263 tag: id,
264 tx: res,
265 span: _span,
266 });
267 this.in_flight.fetch_add(1, atomic::Ordering::AcqRel);
268
269 i += 1;
271 if i == crate::YIELD_EVERY {
272 cx.waker().wake_by_ref();
274 break;
276 }
277 }
278 Poll::Ready(None) => {
279 *this.finish = true;
281 break;
282 }
283 Poll::Pending => {
284 break;
286 }
287 }
288 }
289 }
290
291 if this.in_flight.load(atomic::Ordering::Acquire) != 0 && !*this.rx_only {
292 if *this.finish {
295 let r = transport
299 .as_mut()
300 .poll_close(cx)
301 .map_err(Error::from_sink_error)?;
302
303 if r.is_ready() {
304 *this.rx_only = true;
307 }
308 } else {
309 let _ = transport
310 .as_mut()
311 .poll_flush(cx)
312 .map_err(Error::from_sink_error)?;
313 }
314 }
315
316 while this.in_flight.load(atomic::Ordering::Acquire) != 0 {
321 match ready!(transport.as_mut().try_poll_next(cx))
322 .transpose()
323 .map_err(Error::from_stream_error)?
324 {
325 Some(r) => {
326 let id = transport.as_mut().finish_tag(&r);
331 let pending = this
332 .responses
333 .iter()
334 .position(|&Pending { ref tag, .. }| tag == &id)
335 .ok_or(Error::Desynchronized)?;
336
337 let pending = this.responses.swap_remove_front(pending).unwrap();
341 tracing::trace!(parent: &pending.span, "response arrived; forwarding");
342
343 let sender = pending.tx;
346 let _ = sender.send(ClientResponse {
347 response: r,
348 span: pending.span,
349 });
350 this.in_flight.fetch_sub(1, atomic::Ordering::AcqRel);
351 }
352 None => {
353 return Poll::Ready(Err(E::from(Error::BrokenTransportRecv(None))));
356 }
357 }
358 }
359
360 if *this.finish && this.in_flight.load(atomic::Ordering::Acquire) == 0 {
361 if *this.rx_only {
362 } else {
364 ready!(transport.poll_close(cx)).map_err(Error::from_sink_error)?;
366 }
367 return Poll::Ready(Ok(()));
368 }
369
370 Poll::Pending
376 }
377}
378
379impl<T, E, Request> Service<Request> for Client<T, E, Request>
380where
381 T: Sink<Request> + TryStream + TagStore<Request, <T as TryStream>::Ok>,
382 E: From<Error<T, Request>>,
383 E: 'static + Send,
384 Request: 'static + Send,
385 T: 'static,
386 T::Ok: 'static + Send,
387{
388 type Response = T::Ok;
389 type Error = E;
390 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
391
392 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
393 Poll::Ready(ready!(self.mediator.poll_ready(cx)).map_err(|_| E::from(Error::ClientDropped)))
394 }
395
396 fn call(&mut self, req: Request) -> Self::Future {
397 let (tx, rx) = tokio::sync::oneshot::channel();
398 let span = tracing::Span::current();
399 tracing::trace!("issuing request");
400 let req = ClientRequest { req, span, res: tx };
401 let r = self.mediator.try_send(req);
402 Box::pin(async move {
403 match r {
404 Ok(()) => match rx.await {
405 Ok(r) => {
406 tracing::trace!(parent: &r.span, "response returned");
407 Ok(r.response)
408 }
409 Err(_) => Err(E::from(Error::ClientDropped)),
410 },
411 Err(_) => Err(E::from(Error::TransportFull)),
412 }
413 })
414 }
415}
416
417impl<T, E, Request> tower::load::Load for Client<T, E, Request>
418where
419 T: Sink<Request> + TryStream,
420{
421 type Metric = usize;
422
423 fn load(&self) -> Self::Metric {
424 self.in_flight.load(atomic::Ordering::Acquire)
425 }
426}
427
428impl<T> fmt::Display for SpawnError<T>
431where
432 T: fmt::Debug,
433{
434 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435 match *self {
436 SpawnError::SpawnFailed => f.pad("error spawning multiplex client"),
437 SpawnError::Inner(_) => f.pad("error making new multiplex transport"),
438 }
439 }
440}
441
442impl<T> error::Error for SpawnError<T>
443where
444 T: error::Error + 'static,
445{
446 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
447 match *self {
448 SpawnError::SpawnFailed => None,
449 SpawnError::Inner(ref te) => Some(te),
450 }
451 }
452}