// tokio_tower/pipeline/client.rs

1use crate::mediator;
2use crate::wrappers::*;
3use crate::Error;
4use crate::MakeTransport;
5use futures_core::{ready, stream::TryStream};
6use futures_sink::Sink;
7use pin_project::pin_project;
8use std::collections::VecDeque;
9use std::future::Future;
10use std::marker::PhantomData;
11use std::pin::Pin;
12use std::sync::{atomic, Arc};
13use std::task::{Context, Poll};
14use std::{error, fmt};
15use tower_service::Service;
16
/// Factory for [`Client`]s.
///
/// Every request made to this factory asks the wrapped `MakeTransport` for a new transport and
/// then wraps that transport in a brand-new `Client`.
pub struct Maker<NT, Request> {
    /// Ties the factory to the request type its clients accept; holds no data.
    _req: PhantomData<fn(Request)>,
    /// The transport factory that produces a transport for each new client.
    t_maker: NT,
}
23
24impl<NT, Request> fmt::Debug for Maker<NT, Request>
25where
26    NT: fmt::Debug,
27{
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("Maker")
30            .field("t_maker", &self.t_maker)
31            .finish()
32    }
33}
34
35impl<NT, Request> Maker<NT, Request> {
36    /// Make a new `Client` factory that uses the given `MakeTransport` factory.
37    pub fn new(t: NT) -> Self {
38        Maker {
39            t_maker: t,
40            _req: PhantomData,
41        }
42    }
43
44    // NOTE: it'd be *great* if the user had a way to specify a service error handler for all
45    // spawned services, but without https://github.com/rust-lang/rust/pull/49224 or
46    // https://github.com/rust-lang/rust/issues/29625 that's pretty tricky (unless we're willing to
47    // require Fn + Clone)
48}
49
/// A failure to spawn a new `Client`.
///
/// This is the error type of [`Maker`]'s `Service` implementation.
#[derive(Debug)]
pub enum SpawnError<E> {
    /// The client's background worker could not be spawned.
    ///
    /// Kept for API compatibility only. The worker is started with `tokio::spawn`, which does
    /// not report spawn failure as an error value, so `Maker` never produces this variant.
    SpawnFailed,

    /// A new transport could not be produced.
    ///
    /// Wraps the error returned by the underlying `MakeTransport` factory.
    Inner(E),
}
59
60impl<NT, Target, Request> Service<Target> for Maker<NT, Request>
61where
62    NT: MakeTransport<Target, Request>,
63    NT::Transport: 'static + Send,
64    Request: 'static + Send,
65    NT::Item: 'static + Send,
66    NT::SinkError: 'static + Send + Sync,
67    NT::Error: 'static + Send + Sync,
68    NT::Future: 'static + Send,
69{
70    type Error = SpawnError<NT::MakeError>;
71    type Response = Client<NT::Transport, Error<NT::Transport, Request>, Request>;
72    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
73
74    fn call(&mut self, target: Target) -> Self::Future {
75        let maker = self.t_maker.make_transport(target);
76        Box::pin(async move { Ok(Client::new(maker.await.map_err(SpawnError::Inner)?)) })
77    }
78
79    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80        self.t_maker.poll_ready(cx).map_err(SpawnError::Inner)
81    }
82}
83
84impl<NT, Request> tower::load::Load for Maker<NT, Request> {
85    type Metric = u8;
86
87    fn load(&self) -> Self::Metric {
88        0
89    }
90}
91
// ===== Client =====
93
94/// This type provides an implementation of a Tower
95/// [`Service`](https://docs.rs/tokio-service/0.1/tokio_service/trait.Service.html) on top of a
96/// request-at-a-time protocol transport. In particular, it wraps a transport that implements
97/// `Sink<SinkItem = Request>` and `Stream<Item = Response>` with the necessary bookkeeping to
98/// adhere to Tower's convenient `fn(Request) -> Future<Response>` API.
99pub struct Client<T, E, Request>
100where
101    T: Sink<Request> + TryStream,
102{
103    mediator: mediator::Sender<ClientRequest<T, Request>>,
104    in_flight: Arc<atomic::AtomicUsize>,
105    _error: PhantomData<fn(E)>,
106}
107
108impl<T, E, Request> fmt::Debug for Client<T, E, Request>
109where
110    T: Sink<Request> + TryStream,
111{
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        f.debug_struct("Client")
114            .field("mediator", &self.mediator)
115            .field("in_flight", &self.in_flight)
116            .finish()
117    }
118}
119
// ===== ClientInner =====
121
/// Bookkeeping for one request that has been sent to the transport but not yet answered.
struct Pending<Item> {
    /// Channel back to the caller waiting in `Client::call`; receives the response.
    tx: tokio::sync::oneshot::Sender<ClientResponse<Item>>,
    /// The caller's tracing span, carried along so the response can be traced within it.
    span: tracing::Span,
}
126
/// State of the background task that drives a pipelined transport on behalf of [`Client`].
///
/// `Client::with_error_handler` spawns it and then polls it as a `Future`. Requests arrive
/// through `mediator`; responses leave through the per-request channels queued in `responses`.
#[pin_project]
struct ClientInner<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    /// Receiving end of the request channel fed by the `Client` handle.
    mediator: mediator::Receiver<ClientRequest<T, Request>>,
    /// Reply channels (and tracing spans) for requests already sent to the transport, oldest
    /// first. Each response from the transport is delivered to the front entry.
    responses: VecDeque<Pending<T::Ok>>,
    /// The transport: its `Sink` half takes requests, its `TryStream` half yields responses.
    #[pin]
    transport: T,

    /// Number of requests sent to the transport whose responses have not been delivered yet.
    /// Shared with the `Client` handle, which reports it through `Load`.
    in_flight: Arc<atomic::AtomicUsize>,
    /// Set once the mediator reports that no further requests will arrive. The task then
    /// closes the sink and finishes after draining outstanding responses.
    finish: bool,
    /// Set once `poll_close` on the transport has completed; from then on the task only
    /// receives.
    rx_only: bool,

    // `E` appears only in the `Future` output type; this marker ties it to the struct. The
    // field is written but never read, hence the `allow`.
    #[allow(unused)]
    error: PhantomData<fn(E)>,
}
144
145impl<T, E, Request> Client<T, E, Request>
146where
147    T: Sink<Request> + TryStream + Send + 'static,
148    E: From<Error<T, Request>>,
149    E: 'static + Send,
150    Request: 'static + Send,
151    T::Ok: 'static + Send,
152{
153    /// Construct a new [`Client`] over the given `transport`.
154    ///
155    /// If the Client errors, the error is dropped when `new` is used -- use `with_error_handler`
156    /// to handle such an error explicitly.
157    pub fn new(transport: T) -> Self where {
158        Self::with_error_handler(transport, |_| {})
159    }
160
161    /// Construct a new [`Client`] over the given `transport`.
162    ///
163    /// If the `Client` errors, its error is passed to `on_service_error`.
164    pub fn with_error_handler<F>(transport: T, on_service_error: F) -> Self
165    where
166        F: FnOnce(E) + Send + 'static,
167    {
168        let (tx, rx) = mediator::new();
169        let in_flight = Arc::new(atomic::AtomicUsize::new(0));
170        tokio::spawn({
171            let c = ClientInner {
172                mediator: rx,
173                responses: Default::default(),
174                transport,
175                in_flight: in_flight.clone(),
176                error: PhantomData::<fn(E)>,
177                finish: false,
178                rx_only: false,
179            };
180            async move {
181                if let Err(e) = c.await {
182                    on_service_error(e);
183                }
184            }
185        });
186        Client {
187            mediator: tx,
188            in_flight,
189            _error: PhantomData,
190        }
191    }
192}
193
/// Drives the pipeline. Each poll forwards queued requests into the transport's `Sink` half,
/// flushes (or closes) the sink, and routes items from the `TryStream` half back to the waiting
/// callers in the order their requests were sent.
///
/// Resolves to `Ok(())` once all of the following hold:
/// - the mediator has reported that no more requests will arrive (`try_recv` yielded
///   `Ready(None)`);
/// - every in-flight request has received its response;
/// - the sink has been closed.
///
/// Resolves to `Err` in these cases:
/// - the sink or the stream reports an error;
/// - the stream ends while responses are still outstanding;
/// - a response arrives with no pending request to deliver it to.
impl<T, E, Request> Future for ClientInner<T, E, Request>
where
    T: Sink<Request> + TryStream,
    E: From<Error<T, Request>>,
    E: 'static + Send,
    Request: 'static + Send,
    T::Ok: 'static + Send,
{
    type Output = Result<(), E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // project the pinned struct so its fields can be borrowed independently
        let this = self.project();

        // `transport` is a `#[pin]` field, so it is only ever used through `Pin<&mut T>`;
        // `as_mut()` reborrows it for each call below without moving it
        let mut transport: Pin<_> = this.transport;

        // number of requests forwarded during this poll; used to force a periodic yield
        let mut i = 0;

        // Phase 1: while the sink can accept an item, pull requests off the mediator and send
        // them. Skipped once the mediator has reported that no more requests will arrive.
        if !*this.finish {
            while let Poll::Ready(r) = transport.as_mut().poll_ready(cx) {
                if let Err(e) = r {
                    return Poll::Ready(Err(E::from(Error::from_sink_error(e))));
                }

                // send more requests if we have them
                match this.mediator.try_recv(cx) {
                    Poll::Ready(Some(ClientRequest {
                        req,
                        span: _span,
                        res,
                    })) => {
                        // trace the hand-off to the sink inside the caller's span
                        let guard = _span.enter();
                        tracing::trace!("request received by worker; sending to Sink");

                        transport
                            .as_mut()
                            .start_send(req)
                            .map_err(Error::from_sink_error)?;
                        tracing::trace!("request sent");
                        drop(guard);

                        // the transport is assumed to answer in send order, so the reply
                        // channel goes to the back of the queue and responses are matched
                        // against the front. `responses` and `in_flight` always change together.
                        this.responses.push_back(Pending {
                            tx: res,
                            span: _span,
                        });
                        this.in_flight.fetch_add(1, atomic::Ordering::AcqRel);

                        // if we have run for a while without yielding, yield so we can make progress
                        i += 1;
                        if i == crate::YIELD_EVERY {
                            // we're forcing a yield, so need to ensure we get woken up again
                            cx.waker().wake_by_ref();
                            // we still want to execute the code below the loop
                            break;
                        }
                    }
                    Poll::Ready(None) => {
                        // the mediator is closed (presumably the `Client` handle is gone):
                        // accept no more requests and begin shutting down
                        // XXX: should we "give up" the Sink::poll_ready here?
                        *this.finish = true;
                        break;
                    }
                    Poll::Pending => {
                        // XXX: should we "give up" the Sink::poll_ready here?
                        break;
                    }
                }
            }
        }

        // Phase 2: push previously sent requests out of the sink. Only needed while requests
        // are outstanding and the sink has not already been closed.
        if this.in_flight.load(atomic::Ordering::Acquire) != 0 && !*this.rx_only {
            // flush out any stuff we've sent in the past
            // don't return on Pending since we have to check for responses too
            if *this.finish {
                // we're closing up shop!
                //
                // poll_close() implies poll_flush()
                let r = transport
                    .as_mut()
                    .poll_close(cx)
                    .map_err(Error::from_sink_error)?;

                if r.is_ready() {
                    // now that close has completed, we should never send anything again
                    // we only need to receive to make the in-flight requests complete
                    *this.rx_only = true;
                }
            } else {
                let _ = transport
                    .as_mut()
                    .poll_flush(cx)
                    .map_err(Error::from_sink_error)?;
            }
        }

        // Phase 3: and start looking for replies.
        //
        // note that we *could* have this just be a loop, but we don't want to poll the stream
        // if we know there's nothing for it to produce.
        while this.in_flight.load(atomic::Ordering::Acquire) != 0 {
            // `ready!` returns Pending from `poll` if the stream has no item yet
            match ready!(transport.as_mut().try_poll_next(cx))
                .transpose()
                .map_err(Error::from_stream_error)?
            {
                Some(r) => {
                    // defensive: `in_flight != 0` implies `responses` is non-empty, because
                    // both are updated together; this only fails if that invariant breaks
                    let pending = this.responses.pop_front().ok_or(Error::Desynchronized)?;
                    tracing::trace!(parent: &pending.span, "response arrived; forwarding");

                    // ignore send failures
                    // the client may just no longer care about the response
                    let sender = pending.tx;
                    let _ = sender.send(ClientResponse {
                        response: r,
                        span: pending.span,
                    });
                    this.in_flight.fetch_sub(1, atomic::Ordering::AcqRel);
                }
                None => {
                    // the transport terminated while we were waiting for a response!
                    // TODO: it'd be nice if we could return the transport here..
                    return Poll::Ready(Err(E::from(Error::BrokenTransportRecv(None))));
                }
            }
        }

        // Phase 4: every response has been delivered. If no more requests can arrive, finish
        // closing the sink (unless Phase 2 already did) and complete.
        if *this.finish && this.in_flight.load(atomic::Ordering::Acquire) == 0 {
            if *this.rx_only {
                // we have already closed the send side.
            } else {
                // we're completely done once close() finishes!
                ready!(transport.poll_close(cx)).map_err(Error::from_sink_error)?;
            }
            return Poll::Ready(Ok(()));
        }

        // to get here, we must have no requests in flight and not be finishing. Phase 1 then
        // stopped in one of three ways:
        // - `transport.poll_ready` returned Pending;
        // - `this.mediator.try_recv` returned Pending;
        // - we forced a yield, in which case we already scheduled our own wakeup.
        // The first two are expected to have registered a wakeup with `cx`, so the right thing
        // for us to do is return Pending too!
        Poll::Pending
    }
}
338
339impl<T, E, Request> Service<Request> for Client<T, E, Request>
340where
341    T: Sink<Request> + TryStream,
342    E: From<Error<T, Request>>,
343    E: 'static + Send,
344    Request: 'static + Send,
345    T: 'static,
346    T::Ok: 'static + Send,
347{
348    type Response = T::Ok;
349    type Error = E;
350    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
351
352    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
353        Poll::Ready(ready!(self.mediator.poll_ready(cx)).map_err(|_| E::from(Error::ClientDropped)))
354    }
355
356    fn call(&mut self, req: Request) -> Self::Future {
357        let (tx, rx) = tokio::sync::oneshot::channel();
358        let span = tracing::Span::current();
359        tracing::trace!("issuing request");
360        let req = ClientRequest { req, span, res: tx };
361        let r = self.mediator.try_send(req);
362        Box::pin(async move {
363            match r {
364                Ok(()) => match rx.await {
365                    Ok(r) => {
366                        tracing::trace!(parent: &r.span, "response returned");
367                        Ok(r.response)
368                    }
369                    Err(_) => Err(E::from(Error::ClientDropped)),
370                },
371                Err(_) => Err(E::from(Error::TransportFull)),
372            }
373        })
374    }
375}
376
377impl<T, E, Request> tower::load::Load for Client<T, E, Request>
378where
379    T: Sink<Request> + TryStream,
380{
381    type Metric = usize;
382
383    fn load(&self) -> Self::Metric {
384        self.in_flight.load(atomic::Ordering::Acquire)
385    }
386}
387
// ===== impl SpawnError =====
389
390impl<T> fmt::Display for SpawnError<T>
391where
392    T: fmt::Debug,
393{
394    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
395        match *self {
396            SpawnError::SpawnFailed => f.pad("error spawning multiplex client"),
397            SpawnError::Inner(_) => f.pad("error making new multiplex transport"),
398        }
399    }
400}
401
402impl<T> error::Error for SpawnError<T>
403where
404    T: error::Error + 'static,
405{
406    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
407        match *self {
408            SpawnError::SpawnFailed => None,
409            SpawnError::Inner(ref te) => Some(te),
410        }
411    }
412}