1use super::{Instrument, InstrumentFuture, NoInstrument};
4use crate::Load;
5use futures_core::ready;
6use pin_project::pin_project;
7use std::sync::Arc;
8use std::{
9 pin::Pin,
10 task::{Context, Poll},
11};
12use tower_discover::{Change, Discover};
13use tower_service::Service;
14
/// Wraps a service `S` and measures its load as the number of requests that
/// still hold a [`Handle`].
///
/// Every call hands a new `Handle` (sharing `ref_count`) to the response
/// future, together with a clone of `instrument`. The load drops once that
/// handle is dropped.
#[derive(Debug)]
pub struct PendingRequests<S, I = NoInstrument> {
    // The wrapped service; readiness checks and requests are forwarded to it.
    service: S,
    // Shared counter: this struct owns one reference, each live `Handle` one more.
    ref_count: RefCount,
    // Cloned into every response future; its `Output` becomes this service's `Response`.
    instrument: I,
}
22
/// Shared counter whose value is the number of live clones of the inner `Arc`.
///
/// Cloning a `RefCount` adds a strong reference and dropping a clone removes
/// one. The `Arc` holds no data (`()`); only its strong count is used.
#[derive(Clone, Debug, Default)]
struct RefCount(Arc<()>);
27
/// Wraps a [`Discover`] so that every service it inserts is wrapped in
/// [`PendingRequests`], each receiving its own clone of `instrument`.
#[pin_project]
#[derive(Debug)]
pub struct PendingRequestsDiscover<D, I = NoInstrument> {
    // Pinned so `poll_discover` can be forwarded through `Pin<&mut Self>`.
    #[pin]
    discover: D,
    // Cloned into each newly inserted `PendingRequests`.
    instrument: I,
}
36
/// The load metric reported by [`PendingRequests`]: the number of requests
/// whose [`Handle`] is still alive.
///
/// The derived ordering compares the wrapped counts, so a smaller `Count`
/// means fewer pending requests.
#[derive(Clone, Copy, Debug, Default, PartialOrd, PartialEq, Ord, Eq)]
pub struct Count(usize);
40
/// Token for one pending request.
///
/// While it is alive, the [`PendingRequests`] that created it counts the
/// request in its load. Dropping it removes the request from the count.
#[derive(Debug)]
pub struct Handle(RefCount);
44
45impl<S, I> PendingRequests<S, I> {
48 fn new(service: S, instrument: I) -> Self {
49 Self {
50 service,
51 instrument,
52 ref_count: RefCount::default(),
53 }
54 }
55
56 fn handle(&self) -> Handle {
57 Handle(self.ref_count.clone())
58 }
59}
60
61impl<S, I> Load for PendingRequests<S, I> {
62 type Metric = Count;
63
64 fn load(&self) -> Count {
65 Count(self.ref_count.ref_count() - 1)
67 }
68}
69
70impl<S, I, Request> Service<Request> for PendingRequests<S, I>
71where
72 S: Service<Request>,
73 I: Instrument<Handle, S::Response>,
74{
75 type Response = I::Output;
76 type Error = S::Error;
77 type Future = InstrumentFuture<S::Future, I, Handle>;
78
79 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80 self.service.poll_ready(cx)
81 }
82
83 fn call(&mut self, req: Request) -> Self::Future {
84 InstrumentFuture::new(
85 self.instrument.clone(),
86 self.handle(),
87 self.service.call(req),
88 )
89 }
90}
91
92impl<D, I> PendingRequestsDiscover<D, I> {
95 pub fn new<Request>(discover: D, instrument: I) -> Self
97 where
98 D: Discover,
99 D::Service: Service<Request>,
100 I: Instrument<Handle, <D::Service as Service<Request>>::Response>,
101 {
102 Self {
103 discover,
104 instrument,
105 }
106 }
107}
108
109impl<D, I> Discover for PendingRequestsDiscover<D, I>
110where
111 D: Discover,
112 I: Clone,
113{
114 type Key = D::Key;
115 type Service = PendingRequests<D::Service, I>;
116 type Error = D::Error;
117
118 fn poll_discover(
120 self: Pin<&mut Self>,
121 cx: &mut Context<'_>,
122 ) -> Poll<Result<Change<D::Key, Self::Service>, D::Error>> {
123 use self::Change::*;
124
125 let this = self.project();
126 let change = match ready!(this.discover.poll_discover(cx))? {
127 Insert(k, svc) => Insert(k, PendingRequests::new(svc, this.instrument.clone())),
128 Remove(k) => Remove(k),
129 };
130
131 Poll::Ready(Ok(change))
132 }
133}
134
135impl RefCount {
138 pub(crate) fn ref_count(&self) -> usize {
139 Arc::strong_count(&self.0)
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::task::{Context, Poll};

    /// Minimal service that is always ready and answers every request with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// With `NoInstrument`, each call raises the load by one until its response
    /// future has been run (and consumed) by `block_on`.
    #[test]
    fn default() {
        let mut svc = PendingRequests::new(Svc, NoInstrument);
        assert_eq!(svc.load(), Count(0));

        let rsp0 = svc.call(());
        assert_eq!(svc.load(), Count(1));

        let rsp1 = svc.call(());
        assert_eq!(svc.load(), Count(2));

        // Completing and dropping a response future releases its handle.
        let () = tokio_test::block_on(rsp0).unwrap();
        assert_eq!(svc.load(), Count(1));

        let () = tokio_test::block_on(rsp1).unwrap();
        assert_eq!(svc.load(), Count(0));
    }

    /// An instrument that returns the `Handle` as the response keeps the request
    /// counted after the future completes, until that `Handle` is dropped.
    #[test]
    fn instrumented() {
        // Instrument that hands the request's handle back to the caller.
        #[derive(Clone)]
        struct IntoHandle;
        impl Instrument<Handle, ()> for IntoHandle {
            type Output = Handle;
            fn instrument(&self, i: Handle, (): ()) -> Handle {
                i
            }
        }

        let mut svc = PendingRequests::new(Svc, IntoHandle);
        assert_eq!(svc.load(), Count(0));

        let rsp = svc.call(());
        assert_eq!(svc.load(), Count(1));
        // The future is done, but `i0` still holds the handle, so the load is unchanged.
        let i0 = tokio_test::block_on(rsp).unwrap();
        assert_eq!(svc.load(), Count(1));

        let rsp = svc.call(());
        assert_eq!(svc.load(), Count(2));
        let i1 = tokio_test::block_on(rsp).unwrap();
        assert_eq!(svc.load(), Count(2));

        // Dropping the returned handles is what finally lowers the load.
        drop(i1);
        assert_eq!(svc.load(), Count(1));

        drop(i0);
        assert_eq!(svc.load(), Count(0));
    }
}