tower_load/
pending_requests.rs

1//! A `Load` implementation that uses the count of in-flight requests.
2
3use super::{Instrument, InstrumentFuture, NoInstrument};
4use crate::Load;
5use futures_core::ready;
6use pin_project::pin_project;
7use std::sync::Arc;
8use std::{
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tower_discover::{Change, Discover};
13use tower_service::Service;
14
/// Expresses load based on the number of currently-pending requests.
///
/// The wrapper keeps a reference count that is cloned into a [`Handle`] for
/// every call; the reported load is the number of those clones still alive.
#[derive(Debug)]
pub struct PendingRequests<S, I = NoInstrument> {
    /// The wrapped service that requests are forwarded to.
    service: S,
    /// Reference count shared with every outstanding `Handle`.
    ref_count: RefCount,
    /// Combines each request's `Handle` with the inner service's response.
    instrument: I,
}
22
/// A cloneable token whose number of live clones is the quantity being
/// measured.
///
/// `PendingRequests` owns one clone and each `Handle` owns another; the `Arc`
/// carries no data and is used purely for its strong count.
#[derive(Debug, Default, Clone)]
struct RefCount(Arc<()>);
27
/// Wraps the services yielded by `discover` with `PendingRequests`.
#[pin_project]
#[derive(Debug)]
pub struct PendingRequestsDiscover<D, I = NoInstrument> {
    /// The underlying discovery stream; pinned because `poll_discover` takes
    /// `Pin<&mut Self>`.
    #[pin]
    discover: D,
    /// Cloned into every `PendingRequests` created for an inserted service.
    instrument: I,
}
36
/// The load metric reported by `PendingRequests`: how many requests to a
/// given service are currently pending.
///
/// Counts compare and order by their numeric value, and the default is zero.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Count(usize);
40
/// Tracks an in-flight request by reference count.
///
/// A handle holds a clone of its service's reference count, so the request it
/// represents contributes to the service's load for as long as the handle is
/// alive.
#[derive(Debug)]
pub struct Handle(RefCount);
44
45// ===== impl PendingRequests =====
46
47impl<S, I> PendingRequests<S, I> {
48    fn new(service: S, instrument: I) -> Self {
49        Self {
50            service,
51            instrument,
52            ref_count: RefCount::default(),
53        }
54    }
55
56    fn handle(&self) -> Handle {
57        Handle(self.ref_count.clone())
58    }
59}
60
61impl<S, I> Load for PendingRequests<S, I> {
62    type Metric = Count;
63
64    fn load(&self) -> Count {
65        // Count the number of references that aren't `self`.
66        Count(self.ref_count.ref_count() - 1)
67    }
68}
69
70impl<S, I, Request> Service<Request> for PendingRequests<S, I>
71where
72    S: Service<Request>,
73    I: Instrument<Handle, S::Response>,
74{
75    type Response = I::Output;
76    type Error = S::Error;
77    type Future = InstrumentFuture<S::Future, I, Handle>;
78
79    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80        self.service.poll_ready(cx)
81    }
82
83    fn call(&mut self, req: Request) -> Self::Future {
84        InstrumentFuture::new(
85            self.instrument.clone(),
86            self.handle(),
87            self.service.call(req),
88        )
89    }
90}
91
92// ===== impl PendingRequestsDiscover =====
93
94impl<D, I> PendingRequestsDiscover<D, I> {
95    /// Wraps a `Discover``, wrapping all of its services with `PendingRequests`.
96    pub fn new<Request>(discover: D, instrument: I) -> Self
97    where
98        D: Discover,
99        D::Service: Service<Request>,
100        I: Instrument<Handle, <D::Service as Service<Request>>::Response>,
101    {
102        Self {
103            discover,
104            instrument,
105        }
106    }
107}
108
109impl<D, I> Discover for PendingRequestsDiscover<D, I>
110where
111    D: Discover,
112    I: Clone,
113{
114    type Key = D::Key;
115    type Service = PendingRequests<D::Service, I>;
116    type Error = D::Error;
117
118    /// Yields the next discovery change set.
119    fn poll_discover(
120        self: Pin<&mut Self>,
121        cx: &mut Context<'_>,
122    ) -> Poll<Result<Change<D::Key, Self::Service>, D::Error>> {
123        use self::Change::*;
124
125        let this = self.project();
126        let change = match ready!(this.discover.poll_discover(cx))? {
127            Insert(k, svc) => Insert(k, PendingRequests::new(svc, this.instrument.clone())),
128            Remove(k) => Remove(k),
129        };
130
131        Poll::Ready(Ok(change))
132    }
133}
134
135// ==== RefCount ====
136
137impl RefCount {
138    pub(crate) fn ref_count(&self) -> usize {
139        Arc::strong_count(&self.0)
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::task::{Context, Poll};

    /// A trivial service that is always ready and succeeds immediately.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// With `NoInstrument`, a request stops counting once its response future
    /// has been resolved.
    #[test]
    fn default() {
        let mut service = PendingRequests::new(Svc, NoInstrument);
        assert_eq!(service.load(), Count(0));

        // Each call adds one pending request.
        let first = service.call(());
        assert_eq!(service.load(), Count(1));

        let second = service.call(());
        assert_eq!(service.load(), Count(2));

        // Resolving the futures releases them one at a time.
        let () = tokio_test::block_on(first).unwrap();
        assert_eq!(service.load(), Count(1));

        let () = tokio_test::block_on(second).unwrap();
        assert_eq!(service.load(), Count(0));
    }

    /// When the instrument returns the `Handle` as the response, the request
    /// keeps counting until that handle is dropped.
    #[test]
    fn instrumented() {
        /// Instrument that hands the tracking handle back to the caller.
        #[derive(Clone)]
        struct IntoHandle;
        impl Instrument<Handle, ()> for IntoHandle {
            type Output = Handle;
            fn instrument(&self, i: Handle, (): ()) -> Handle {
                i
            }
        }

        let mut service = PendingRequests::new(Svc, IntoHandle);
        assert_eq!(service.load(), Count(0));

        // Resolving the future does not lower the load: the handle lives on.
        let pending = service.call(());
        assert_eq!(service.load(), Count(1));
        let first_handle = tokio_test::block_on(pending).unwrap();
        assert_eq!(service.load(), Count(1));

        let pending = service.call(());
        assert_eq!(service.load(), Count(2));
        let second_handle = tokio_test::block_on(pending).unwrap();
        assert_eq!(service.load(), Count(2));

        // Only dropping the handles releases the requests.
        drop(second_handle);
        assert_eq!(service.load(), Count(1));

        drop(first_handle);
        assert_eq!(service.load(), Count(0));
    }
}