1use crate::mediator;
2use crate::wrappers::*;
3use crate::Error;
4use crate::MakeTransport;
5use futures_core::{ready, stream::TryStream};
6use futures_sink::Sink;
7use pin_project::pin_project;
8use std::collections::VecDeque;
9use std::future::Future;
10use std::marker::PhantomData;
11use std::pin::Pin;
12use std::sync::{atomic, Arc};
13use std::task::{Context, Poll};
14use std::{error, fmt};
15use tower_service::Service;
16
/// A [`Service`] that, on each call, builds a new transport with the wrapped
/// transport maker and returns a [`Client`] driving it.
pub struct Maker<NT, Request> {
    // Transport factory; `make_transport` is invoked once per `call`.
    t_maker: NT,
    // Records `Request` in the type without storing a value. The `fn(Request)`
    // form keeps `Maker`'s auto traits (`Send`/`Sync`) independent of `Request`.
    _req: PhantomData<fn(Request)>,
}
23
24impl<NT, Request> fmt::Debug for Maker<NT, Request>
25where
26 NT: fmt::Debug,
27{
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 f.debug_struct("Maker")
30 .field("t_maker", &self.t_maker)
31 .finish()
32 }
33}
34
35impl<NT, Request> Maker<NT, Request> {
36 pub fn new(t: NT) -> Self {
38 Maker {
39 t_maker: t,
40 _req: PhantomData,
41 }
42 }
43
44 }
49
/// Error produced by [`Maker`] when it fails to create a new client.
#[derive(Debug)]
pub enum SpawnError<E> {
    /// Spawning the client failed.
    ///
    /// NOTE(review): no code visible in this file constructs this variant;
    /// confirm whether it is still produced anywhere.
    SpawnFailed,

    /// Creating the underlying transport failed; wraps the transport maker's error.
    Inner(E),
}
59
60impl<NT, Target, Request> Service<Target> for Maker<NT, Request>
61where
62 NT: MakeTransport<Target, Request>,
63 NT::Transport: 'static + Send,
64 Request: 'static + Send,
65 NT::Item: 'static + Send,
66 NT::SinkError: 'static + Send + Sync,
67 NT::Error: 'static + Send + Sync,
68 NT::Future: 'static + Send,
69{
70 type Error = SpawnError<NT::MakeError>;
71 type Response = Client<NT::Transport, Error<NT::Transport, Request>, Request>;
72 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
73
74 fn call(&mut self, target: Target) -> Self::Future {
75 let maker = self.t_maker.make_transport(target);
76 Box::pin(async move { Ok(Client::new(maker.await.map_err(SpawnError::Inner)?)) })
77 }
78
79 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80 self.t_maker.poll_ready(cx).map_err(SpawnError::Inner)
81 }
82}
83
84impl<NT, Request> tower::load::Load for Maker<NT, Request> {
85 type Metric = u8;
86
87 fn load(&self) -> Self::Metric {
88 0
89 }
90}
91
/// Handle for issuing requests over transport `T`, which is driven by a
/// background worker task spawned when the client is created.
pub struct Client<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    // Queue into the background worker. Each entry carries the request, its
    // tracing span, and the channel the response is delivered on.
    mediator: mediator::Sender<ClientRequest<T, Request>>,
    // Shared with the worker: the number of requests written to the transport
    // whose responses have not yet been forwarded. Reported as this client's load.
    in_flight: Arc<atomic::AtomicUsize>,
    // `E` is the error type returned to callers; no value of it is stored.
    _error: PhantomData<fn(E)>,
}
107
108impl<T, E, Request> fmt::Debug for Client<T, E, Request>
109where
110 T: Sink<Request> + TryStream,
111{
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 f.debug_struct("Client")
114 .field("mediator", &self.mediator)
115 .field("in_flight", &self.in_flight)
116 .finish()
117 }
118}
119
/// A request that has been written to the transport and is waiting for its response.
struct Pending<Item> {
    // Delivers the response back to the caller awaiting it in `Client::call`.
    tx: tokio::sync::oneshot::Sender<ClientResponse<Item>>,
    // Tracing span of the originating request. Used as the parent for the
    // "response arrived" event and forwarded along with the response.
    span: tracing::Span,
}
126
/// State of the background worker that owns the transport and shuttles
/// requests and responses between it and the [`Client`] handles.
#[pin_project]
struct ClientInner<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    // Receiving side of the request queue fed by `Client::call`.
    mediator: mediator::Receiver<ClientRequest<T, Request>>,
    // Requests already sent, oldest first. Responses are matched to these
    // strictly in arrival order.
    responses: VecDeque<Pending<T::Ok>>,
    // The transport; the only structurally pinned field.
    #[pin]
    transport: T,

    // Shared with `Client`: count of sent-but-unanswered requests.
    in_flight: Arc<atomic::AtomicUsize>,
    // Set once the request queue has reported end-of-stream; no more requests
    // are pulled after that.
    finish: bool,
    // Set once `poll_close` on the sink has completed; afterwards only the
    // receive half of the transport is used.
    rx_only: bool,

    // `E` only appears in the worker's output type. This marker is never read,
    // hence the `unused` allowance.
    #[allow(unused)]
    error: PhantomData<fn(E)>,
}
144
145impl<T, E, Request> Client<T, E, Request>
146where
147 T: Sink<Request> + TryStream + Send + 'static,
148 E: From<Error<T, Request>>,
149 E: 'static + Send,
150 Request: 'static + Send,
151 T::Ok: 'static + Send,
152{
153 pub fn new(transport: T) -> Self where {
158 Self::with_error_handler(transport, |_| {})
159 }
160
161 pub fn with_error_handler<F>(transport: T, on_service_error: F) -> Self
165 where
166 F: FnOnce(E) + Send + 'static,
167 {
168 let (tx, rx) = mediator::new();
169 let in_flight = Arc::new(atomic::AtomicUsize::new(0));
170 tokio::spawn({
171 let c = ClientInner {
172 mediator: rx,
173 responses: Default::default(),
174 transport,
175 in_flight: in_flight.clone(),
176 error: PhantomData::<fn(E)>,
177 finish: false,
178 rx_only: false,
179 };
180 async move {
181 if let Err(e) = c.await {
182 on_service_error(e);
183 }
184 }
185 });
186 Client {
187 mediator: tx,
188 in_flight,
189 _error: PhantomData,
190 }
191 }
192}
193
impl<T, E, Request> Future for ClientInner<T, E, Request>
where
    T: Sink<Request> + TryStream,
    E: From<Error<T, Request>>,
    E: 'static + Send,
    Request: 'static + Send,
    T::Ok: 'static + Send,
{
    type Output = Result<(), E>;

    /// Drives the background worker for one wake-up.
    ///
    /// Each poll runs, in order:
    /// 1. forward queued requests into the sink while it reports ready (after
    ///    `crate::YIELD_EVERY` requests it wakes itself and stops early);
    /// 2. flush the sink, or close it once the request queue has ended;
    /// 3. read responses from the stream and hand each one to the oldest
    ///    pending request (FIFO matching);
    /// 4. complete once the queue has ended, nothing is in flight, and the sink
    ///    has been closed.
    ///
    /// # Errors
    ///
    /// Resolves to `Err` if:
    /// - the sink or the stream reports an error;
    /// - a response arrives with no pending request (`Error::Desynchronized`);
    /// - the stream ends while requests are still in flight
    ///   (`Error::BrokenTransportRecv`).
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Project the pinned struct so fields can be borrowed independently.
        let this = self.project();

        // `transport` is the only structurally pinned field; keep a pinned handle to it.
        let mut transport: Pin<_> = this.transport;

        // Requests forwarded during this poll; used for cooperative yielding.
        let mut i = 0;

        // Phase 1: move requests from the queue into the sink while the sink
        // has capacity. Skipped once the queue has been seen to end.
        if !*this.finish {
            while let Poll::Ready(r) = transport.as_mut().poll_ready(cx) {
                if let Err(e) = r {
                    return Poll::Ready(Err(E::from(Error::from_sink_error(e))));
                }

                match this.mediator.try_recv(cx) {
                    Poll::Ready(Some(ClientRequest {
                        req,
                        span: _span,
                        res,
                    })) => {
                        let guard = _span.enter();
                        tracing::trace!("request received by worker; sending to Sink");

                        transport
                            .as_mut()
                            .start_send(req)
                            .map_err(Error::from_sink_error)?;
                        tracing::trace!("request sent");
                        drop(guard);

                        // Responses are matched by order of arrival, so the
                        // pending entry goes to the back of the queue.
                        this.responses.push_back(Pending {
                            tx: res,
                            span: _span,
                        });
                        this.in_flight.fetch_add(1, atomic::Ordering::AcqRel);

                        // After YIELD_EVERY requests, schedule a re-poll and leave
                        // the loop so other tasks can run. Phases 2-4 below still
                        // execute during this poll.
                        i += 1;
                        if i == crate::YIELD_EVERY {
                            cx.waker().wake_by_ref();
                            break;
                        }
                    }
                    Poll::Ready(None) => {
                        // The request queue reported end-of-stream (presumably all
                        // senders are gone; see `mediator`). No more requests will come.
                        *this.finish = true;
                        break;
                    }
                    Poll::Pending => {
                        break;
                    }
                }
            }
        }

        // Phase 2: push previously sent requests out of the sink. Errors abort
        // the worker, but a pending flush/close does not return early, because
        // responses still need to be checked below.
        if this.in_flight.load(atomic::Ordering::Acquire) != 0 && !*this.rx_only {
            if *this.finish {
                // No more requests will be sent, so close the sink; per the
                // `Sink::poll_close` contract, closing also flushes.
                let r = transport
                    .as_mut()
                    .poll_close(cx)
                    .map_err(Error::from_sink_error)?;

                if r.is_ready() {
                    // The sink is closed; from now on only the receive half is used.
                    *this.rx_only = true;
                }
            } else {
                let _ = transport
                    .as_mut()
                    .poll_flush(cx)
                    .map_err(Error::from_sink_error)?;
            }
        }

        // Phase 3: read responses while any request is outstanding. The stream
        // is not polled when nothing is in flight. If the stream is not ready,
        // `ready!` returns `Pending` from this poll.
        while this.in_flight.load(atomic::Ordering::Acquire) != 0 {
            match ready!(transport.as_mut().try_poll_next(cx))
                .transpose()
                .map_err(Error::from_stream_error)?
            {
                Some(r) => {
                    // Pair the response with the oldest outstanding request. If
                    // none is recorded, request/response ordering has been lost.
                    let pending = this.responses.pop_front().ok_or(Error::Desynchronized)?;
                    tracing::trace!(parent: &pending.span, "response arrived; forwarding");

                    let sender = pending.tx;
                    // A failed send means the receiving side was dropped (the
                    // caller stopped waiting); that is not a worker error.
                    let _ = sender.send(ClientResponse {
                        response: r,
                        span: pending.span,
                    });
                    this.in_flight.fetch_sub(1, atomic::Ordering::AcqRel);
                }
                None => {
                    // The stream ended while responses were still owed.
                    return Poll::Ready(Err(E::from(Error::BrokenTransportRecv(None))));
                }
            }
        }

        // Phase 4: shut down once the queue has ended and nothing is in flight.
        if *this.finish && this.in_flight.load(atomic::Ordering::Acquire) == 0 {
            if *this.rx_only {
                // The sink was already closed in phase 2.
            } else {
                ready!(transport.poll_close(cx)).map_err(Error::from_sink_error)?;
            }
            return Poll::Ready(Ok(()));
        }

        // Nothing could complete this poll. We rely on the `Pending` results
        // above (or the explicit `wake_by_ref` on the yield path) to have
        // arranged a future wake-up.
        Poll::Pending
    }
}
338
339impl<T, E, Request> Service<Request> for Client<T, E, Request>
340where
341 T: Sink<Request> + TryStream,
342 E: From<Error<T, Request>>,
343 E: 'static + Send,
344 Request: 'static + Send,
345 T: 'static,
346 T::Ok: 'static + Send,
347{
348 type Response = T::Ok;
349 type Error = E;
350 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
351
352 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
353 Poll::Ready(ready!(self.mediator.poll_ready(cx)).map_err(|_| E::from(Error::ClientDropped)))
354 }
355
356 fn call(&mut self, req: Request) -> Self::Future {
357 let (tx, rx) = tokio::sync::oneshot::channel();
358 let span = tracing::Span::current();
359 tracing::trace!("issuing request");
360 let req = ClientRequest { req, span, res: tx };
361 let r = self.mediator.try_send(req);
362 Box::pin(async move {
363 match r {
364 Ok(()) => match rx.await {
365 Ok(r) => {
366 tracing::trace!(parent: &r.span, "response returned");
367 Ok(r.response)
368 }
369 Err(_) => Err(E::from(Error::ClientDropped)),
370 },
371 Err(_) => Err(E::from(Error::TransportFull)),
372 }
373 })
374 }
375}
376
377impl<T, E, Request> tower::load::Load for Client<T, E, Request>
378where
379 T: Sink<Request> + TryStream,
380{
381 type Metric = usize;
382
383 fn load(&self) -> Self::Metric {
384 self.in_flight.load(atomic::Ordering::Acquire)
385 }
386}
387
388impl<T> fmt::Display for SpawnError<T>
391where
392 T: fmt::Debug,
393{
394 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395 match *self {
396 SpawnError::SpawnFailed => f.pad("error spawning multiplex client"),
397 SpawnError::Inner(_) => f.pad("error making new multiplex transport"),
398 }
399 }
400}
401
402impl<T> error::Error for SpawnError<T>
403where
404 T: error::Error + 'static,
405{
406 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
407 match *self {
408 SpawnError::SpawnFailed => None,
409 SpawnError::Inner(ref te) => Some(te),
410 }
411 }
412}