tokio_tower/multiplex/client.rs

use crate::mediator;
use crate::wrappers::*;
use crate::Error;
use crate::MakeTransport;
use futures_core::{ready, stream::TryStream};
use futures_sink::Sink;
use pin_project::pin_project;
use std::collections::VecDeque;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{atomic, Arc};
use std::task::{Context, Poll};
use std::{error, fmt};
use tower_service::Service;

// NOTE: this implementation could be more opinionated about request IDs by using a slab, but
// instead, we allow the user to choose their own identifier format.

/// A transport capable of transporting tagged requests and responses must implement this
/// interface in order to be used with a [`Client`].
///
/// Note that we require self to be pinned here as `assign_tag` and `finish_tag` are called on the
/// transport, which is already pinned so that we can use it as a `Stream + Sink`. It wouldn't be
/// safe to then give out `&mut` to the transport without `Pin`, as that might move the transport.
pub trait TagStore<Request, Response> {
    /// The type used for tags.
    type Tag: Eq;

    /// Assign a fresh tag to the given `Request`, and return that tag.
    fn assign_tag(self: Pin<&mut Self>, r: &mut Request) -> Self::Tag;

    /// Retire and return the tag contained in the given `Response`.
    fn finish_tag(self: Pin<&mut Self>, r: &Response) -> Self::Tag;
}
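
// As an illustration, here is a minimal sketch of what implementing `TagStore` can look like.
// `MyTransport`, `Tagged<T>`, and the `next_tag` counter are assumptions invented for this
// example (and the transport must still implement `Sink` and `TryStream` separately); any
// transport whose frames have room for an identifier can follow the same pattern:
//
//     struct Tagged<T> {
//         tag: usize,
//         inner: T,
//     }
//
//     impl<Req, Rsp> TagStore<Tagged<Req>, Tagged<Rsp>> for MyTransport {
//         type Tag = usize;
//
//         fn assign_tag(self: Pin<&mut Self>, r: &mut Tagged<Req>) -> usize {
//             // stamp a fresh identifier onto the outgoing request
//             let this = self.get_mut(); // or `self.project()` if MyTransport is !Unpin
//             let tag = this.next_tag;
//             this.next_tag = this.next_tag.wrapping_add(1);
//             r.tag = tag;
//             tag
//         }
//
//         fn finish_tag(self: Pin<&mut Self>, r: &Tagged<Rsp>) -> usize {
//             // the response echoes back the tag of the request it answers
//             r.tag
//         }
//     }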

/// A factory that makes new [`Client`] instances by creating new transports and wrapping them in
/// fresh `Client`s.
pub struct Maker<NT, Request> {
    t_maker: NT,
    _req: PhantomData<fn(Request)>,
}
43
44impl<NT, Request> fmt::Debug for Maker<NT, Request>
45where
46    NT: fmt::Debug,
47{
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        f.debug_struct("Maker")
50            .field("t_maker", &self.t_maker)
51            .finish()
52    }
53}
54
55impl<NT, Request> Maker<NT, Request> {
56    /// Make a new `Client` factory that uses the given `MakeTransport` factory.
57    pub fn new(t: NT) -> Self {
58        Maker {
59            t_maker: t,
60            _req: PhantomData,
61        }
62    }
63
64    // NOTE: it'd be *great* if the user had a way to specify a service error handler for all
65    // spawned services, but without https://github.com/rust-lang/rust/pull/49224 or
66    // https://github.com/rust-lang/rust/issues/29625 that's pretty tricky (unless we're willing to
67    // require Fn + Clone)
68}
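
// A usage sketch: `Maker` is itself a `Service` whose requests are connection targets, so it
// composes with connection pools and reconnect middleware. `MakeTcp` and `addr` here are
// assumptions for the example, standing in for any `MakeTransport` implementation and target:
//
//     use tower::ServiceExt; // for `ready`
//     let mut mk = Maker::new(MakeTcp::default());
//     let mut client = mk.ready().await?.call(addr).await?;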

/// A failure to spawn a new `Client`.
#[derive(Debug)]
pub enum SpawnError<E> {
    /// The executor failed to spawn the `Client`'s inner worker task.
    SpawnFailed,

    /// A new transport could not be produced.
    Inner(E),
}

impl<NT, Target, Request> Service<Target> for Maker<NT, Request>
where
    NT: MakeTransport<Target, Request>,
    NT::Transport: 'static + Send + TagStore<Request, NT::Item>,
    <NT::Transport as TagStore<Request, NT::Item>>::Tag: 'static + Send,
    Request: 'static + Send,
    NT::Item: 'static + Send,
    NT::SinkError: 'static + Send + Sync,
    NT::Error: 'static + Send + Sync,
    NT::Future: 'static + Send,
{
    type Error = SpawnError<NT::MakeError>;
    type Response = Client<NT::Transport, Error<NT::Transport, Request>, Request>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn call(&mut self, target: Target) -> Self::Future {
        let maker = self.t_maker.make_transport(target);
        Box::pin(async move { Ok(Client::new(maker.await.map_err(SpawnError::Inner)?)) })
    }

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.t_maker.poll_ready(cx).map_err(SpawnError::Inner)
    }
}

impl<NT, Request> tower::load::Load for Maker<NT, Request> {
    type Metric = u8;

    fn load(&self) -> Self::Metric {
        0
    }
}

// ===== Client =====

/// This type provides an implementation of a Tower
/// [`Service`](https://docs.rs/tower-service/latest/tower_service/trait.Service.html) on top of a
/// multiplexed protocol transport. In particular, it wraps a transport that implements
/// `Sink<Request>` and `TryStream<Ok = Response>` with the necessary bookkeeping to adhere to
/// Tower's convenient `fn(Request) -> Future<Response>` API.
pub struct Client<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    mediator: mediator::Sender<ClientRequest<T, Request>>,
    in_flight: Arc<atomic::AtomicUsize>,
    _error: PhantomData<fn(E)>,
}
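
// A usage sketch, assuming `my_transport` is some transport implementing
// `Sink<Request> + TryStream + TagStore` (e.g., a framed TCP connection). Because the protocol
// is multiplexed, several requests can be in flight on the one connection before any response
// is awaited:
//
//     use tower::ServiceExt; // for `ready`
//     let mut client: Client<_, Error<_, _>, _> = Client::new(my_transport);
//     let fut1 = client.ready().await?.call(req1);
//     let fut2 = client.ready().await?.call(req2);
//     let (rsp1, rsp2) = tokio::try_join!(fut1, fut2)?;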

impl<T, E, Request> fmt::Debug for Client<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Client")
            .field("mediator", &self.mediator)
            .field("in_flight", &self.in_flight)
            .finish()
    }
}

// ===== ClientInner =====

struct Pending<Tag, Item> {
    tag: Tag,
    tx: tokio::sync::oneshot::Sender<ClientResponse<Item>>,
    span: tracing::Span,
}

#[pin_project]
struct ClientInner<T, E, Request>
where
    T: Sink<Request> + TryStream + TagStore<Request, <T as TryStream>::Ok>,
{
    mediator: mediator::Receiver<ClientRequest<T, Request>>,
    responses: VecDeque<Pending<T::Tag, T::Ok>>,
    #[pin]
    transport: T,

    in_flight: Arc<atomic::AtomicUsize>,
    /// Set once the `Client` handle has been dropped: no new requests will arrive, so we flush
    /// what we have and close the transport.
    finish: bool,
    /// Set once the send side of the transport has been closed: from then on we only receive,
    /// to complete the requests still in flight.
    rx_only: bool,

    #[allow(unused)]
    error: PhantomData<fn(E)>,
}

impl<T, E, Request> Client<T, E, Request>
where
    T: Sink<Request> + TryStream + TagStore<Request, <T as TryStream>::Ok> + Send + 'static,
    E: From<Error<T, Request>>,
    E: 'static + Send,
    Request: 'static + Send,
    T::Ok: 'static + Send,
    T::Tag: Send,
{
    /// Construct a new [`Client`] over the given `transport`.
    ///
    /// If the `Client` errors, the error is dropped when `new` is used -- use
    /// `with_error_handler` to handle such an error explicitly.
    pub fn new(transport: T) -> Self {
        Self::with_error_handler(transport, |_| {})
    }

    /// Construct a new [`Client`] over the given `transport`.
    ///
    /// If the `Client` errors, its error is passed to `on_service_error`.
    pub fn with_error_handler<F>(transport: T, on_service_error: F) -> Self
    where
        F: FnOnce(E) + Send + 'static,
    {
        let (tx, rx) = mediator::new();
        let in_flight = Arc::new(atomic::AtomicUsize::new(0));
        tokio::spawn({
            let c = ClientInner {
                mediator: rx,
                responses: Default::default(),
                transport,
                in_flight: in_flight.clone(),
                error: PhantomData::<fn(E)>,
                finish: false,
                rx_only: false,
            };
            async move {
                if let Err(e) = c.await {
                    on_service_error(e);
                }
            }
        });
        Client {
            mediator: tx,
            in_flight,
            _error: PhantomData,
        }
    }
}
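
// A sketch of explicit error handling, using the same hypothetical `my_transport` as in the
// example above. The closure runs on the worker task if the transport fails, where `new` would
// silently drop the error:
//
//     let client: Client<_, Error<_, _>, _> =
//         Client::with_error_handler(my_transport, |e| {
//             tracing::error!("multiplex client worker failed: {:?}", e);
//         });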

impl<T, E, Request> Future for ClientInner<T, E, Request>
where
    T: Sink<Request> + TryStream + TagStore<Request, <T as TryStream>::Ok>,
    E: From<Error<T, Request>>,
    E: 'static + Send,
    Request: 'static + Send,
    T::Ok: 'static + Send,
{
    type Output = Result<(), E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // project the pin so we can do partial borrows of our fields
        let this = self.project();

        // we never move transport, nor do we ever hand out &mut to it
        let mut transport: Pin<_> = this.transport;

        // track how many times we have iterated
        let mut i = 0;

        if !*this.finish {
            while let Poll::Ready(r) = transport.as_mut().poll_ready(cx) {
                if let Err(e) = r {
                    return Poll::Ready(Err(E::from(Error::from_sink_error(e))));
                }

                // send more requests if we have them
                match this.mediator.try_recv(cx) {
                    Poll::Ready(Some(ClientRequest { mut req, span, res })) => {
                        let id = transport.as_mut().assign_tag(&mut req);

                        let guard = span.enter();
                        tracing::trace!("request received by worker; sending to Sink");

                        transport
                            .as_mut()
                            .start_send(req)
                            .map_err(Error::from_sink_error)?;
                        tracing::trace!("request sent");
                        drop(guard);

                        this.responses.push_back(Pending {
                            tag: id,
                            tx: res,
                            span,
                        });
                        this.in_flight.fetch_add(1, atomic::Ordering::AcqRel);

                        // if we have run for a while without yielding, yield so that other
                        // tasks can make progress too
                        i += 1;
                        if i == crate::YIELD_EVERY {
                            // we're forcing a yield, so need to ensure we get woken up again
                            cx.waker().wake_by_ref();
                            // we still want to execute the code below the loop
                            break;
                        }
                    }
                    Poll::Ready(None) => {
                        // the Client handle has been dropped: no more requests are coming
                        // XXX: should we "give up" the Sink::poll_ready here?
                        *this.finish = true;
                        break;
                    }
                    Poll::Pending => {
                        // XXX: should we "give up" the Sink::poll_ready here?
                        break;
                    }
                }
            }
        }

        if this.in_flight.load(atomic::Ordering::Acquire) != 0 && !*this.rx_only {
            // flush out any requests we've sent in the past
            // don't return on Pending since we have to check for responses too
            if *this.finish {
                // we're closing up shop!
                //
                // poll_close() implies poll_flush()
                let r = transport
                    .as_mut()
                    .poll_close(cx)
                    .map_err(Error::from_sink_error)?;

                if r.is_ready() {
                    // now that close has completed, we should never send anything again
                    // we only need to receive to make the in-flight requests complete
                    *this.rx_only = true;
                }
            } else {
                let _ = transport
                    .as_mut()
                    .poll_flush(cx)
                    .map_err(Error::from_sink_error)?;
            }
        }

        // and start looking for replies.
        //
        // note that we *could* have this just be a loop, but we don't want to poll the stream
        // if we know there's nothing for it to produce.
        while this.in_flight.load(atomic::Ordering::Acquire) != 0 {
            match ready!(transport.as_mut().try_poll_next(cx))
                .transpose()
                .map_err(Error::from_stream_error)?
            {
                Some(r) => {
                    // find the appropriate response channel.
                    // note that we do a _linear_ scan of the identifiers. this saves us from
                    // keeping a HashMap around, and is _usually_ fast as long as the requests
                    // that have been pending the longest are most likely to complete next.
                    let id = transport.as_mut().finish_tag(&r);
                    let pending = this
                        .responses
                        .iter()
                        .position(|&Pending { ref tag, .. }| tag == &id)
                        .ok_or(Error::Desynchronized)?;

                    // this request just finished, which means it's _probably_ near the front
                    // (i.e., was issued a while ago). so, for the swap needed for efficient
                    // remove, we want to swap with something else that is close to the front.
                    let pending = this.responses.swap_remove_front(pending).unwrap();
                    tracing::trace!(parent: &pending.span, "response arrived; forwarding");

                    // ignore send failures
                    // the client may just no longer care about the response
                    let sender = pending.tx;
                    let _ = sender.send(ClientResponse {
                        response: r,
                        span: pending.span,
                    });
                    this.in_flight.fetch_sub(1, atomic::Ordering::AcqRel);
                }
                None => {
                    // the transport terminated while we were waiting for a response!
                    // TODO: it'd be nice if we could return the transport here..
                    return Poll::Ready(Err(E::from(Error::BrokenTransportRecv(None))));
                }
            }
        }

        if *this.finish && this.in_flight.load(atomic::Ordering::Acquire) == 0 {
            if *this.rx_only {
                // we have already closed the send side.
            } else {
                // we're completely done once close() finishes!
                ready!(transport.poll_close(cx)).map_err(Error::from_sink_error)?;
            }
            return Poll::Ready(Ok(()));
        }

        // to get here, we must have no requests in flight and have gotten Pending from
        // self.mediator.try_recv or self.transport.poll_ready. we *could* also have messages
        // waiting to be sent (transport.poll_flush), but if that's the case it must also have
        // returned Pending. so, at this point, we know that all of our subtasks are either done
        // or have returned Pending, so the right thing for us to do is return Pending too!
        Poll::Pending
    }
}

impl<T, E, Request> Service<Request> for Client<T, E, Request>
where
    T: Sink<Request> + TryStream + TagStore<Request, <T as TryStream>::Ok>,
    E: From<Error<T, Request>>,
    E: 'static + Send,
    Request: 'static + Send,
    T: 'static,
    T::Ok: 'static + Send,
{
    type Response = T::Ok;
    type Error = E;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
        Poll::Ready(ready!(self.mediator.poll_ready(cx)).map_err(|_| E::from(Error::ClientDropped)))
    }

    fn call(&mut self, req: Request) -> Self::Future {
        let (tx, rx) = tokio::sync::oneshot::channel();
        let span = tracing::Span::current();
        tracing::trace!("issuing request");
        let req = ClientRequest { req, span, res: tx };
        let r = self.mediator.try_send(req);
        Box::pin(async move {
            match r {
                Ok(()) => match rx.await {
                    Ok(r) => {
                        tracing::trace!(parent: &r.span, "response returned");
                        Ok(r.response)
                    }
                    Err(_) => Err(E::from(Error::ClientDropped)),
                },
                Err(_) => Err(E::from(Error::TransportFull)),
            }
        })
    }
}

impl<T, E, Request> tower::load::Load for Client<T, E, Request>
where
    T: Sink<Request> + TryStream,
{
    type Metric = usize;

    fn load(&self) -> Self::Metric {
        self.in_flight.load(atomic::Ordering::Acquire)
    }
}
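
// Because `Client` reports its in-flight request count as its load metric, a set of clients can
// sit behind a load-aware balancer such as tower's power-of-two-choices implementation, which
// prefers the replica with the smaller `load()`. A sketch, where `discover` is an assumed
// `Discover` stream yielding `Client` instances:
//
//     let balancer = tower::balance::p2c::Balance::new(discover);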

// ===== impl SpawnError =====

impl<T> fmt::Display for SpawnError<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            SpawnError::SpawnFailed => f.pad("error spawning multiplex client"),
            SpawnError::Inner(_) => f.pad("error making new multiplex transport"),
        }
    }
}

impl<T> error::Error for SpawnError<T>
where
    T: error::Error + 'static,
{
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match *self {
            SpawnError::SpawnFailed => None,
            SpawnError::Inner(ref te) => Some(te),
        }
    }
}