1use super::{Instrument, InstrumentFuture, NoInstrument};
4use crate::Load;
5use futures_core::ready;
6use log::trace;
7use pin_project::pin_project;
8use std::{
9 pin::Pin,
10 task::{Context, Poll},
11};
12use std::{
13 sync::{Arc, Mutex},
14 time::Duration,
15};
16use tokio::time::Instant;
17use tower_discover::{Change, Discover};
18use tower_service::Service;
19
/// Wraps a service `S` and tracks a peak-sensitive exponentially weighted moving
/// average (EWMA) of its round-trip time; the estimate feeds this type's [`Load`] cost.
#[derive(Debug)]
pub struct PeakEwma<S, I = NoInstrument> {
    /// The wrapped inner service that requests are forwarded to.
    service: S,
    /// Decay time constant, in nanoseconds, used to age the estimate as time passes.
    decay_ns: f64,
    /// RTT estimate shared with every outstanding `Handle`; its strong count is used
    /// to derive the number of pending requests.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
    /// Instrument cloned for each call and paired with that call's `Handle`.
    instrument: I,
}
52
/// Wraps a [`Discover`] so that every service it inserts is wrapped in a [`PeakEwma`].
#[pin_project]
#[derive(Debug)]
pub struct PeakEwmaDiscover<D, I = NoInstrument> {
    /// The inner discovery stream; structurally pinned because it is polled via `Pin`.
    #[pin]
    discover: D,
    /// Decay time constant, in nanoseconds, handed to every wrapped service.
    decay_ns: f64,
    /// Initial RTT estimate for each newly discovered service.
    default_rtt: Duration,
    /// Instrument cloned into every wrapped service.
    instrument: I,
}
63
/// The load metric reported by [`PeakEwma`]: the RTT estimate in nanoseconds
/// multiplied by the number of pending requests plus one.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
70
/// Tracks one in-flight request: it records when the request was sent and, when
/// dropped, folds the elapsed time into the service's shared RTT estimate.
///
/// While it is alive it holds a reference to the shared estimate, so it counts as a
/// pending request in the service's load. It is handed to the service's `Instrument`
/// together with the response future; exactly when it is dropped depends on that
/// instrument (NOTE(review): confirm against `InstrumentFuture`).
#[derive(Debug)]
pub struct Handle {
    /// The instant at which the request was issued.
    sent_at: Instant,
    /// Decay time constant, in nanoseconds, copied from the owning service.
    decay_ns: f64,
    /// The estimate shared with the owning service and its other handles.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
}
78
/// The mutable RTT estimate shared by a [`PeakEwma`] and its outstanding handles.
#[derive(Debug)]
struct RttEstimate {
    /// When the estimate was last updated (or created).
    update_at: Instant,
    /// The current RTT estimate, in nanoseconds.
    rtt_ns: f64,
}
85
/// Nanoseconds per millisecond; used to render nanosecond estimates as milliseconds.
const NANOS_PER_MILLI: f64 = 1_000_000.0;
87
88impl<D, I> PeakEwmaDiscover<D, I> {
91 pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, instrument: I) -> Self
99 where
100 D: Discover,
101 D::Service: Service<Request>,
102 I: Instrument<Handle, <D::Service as Service<Request>>::Response>,
103 {
104 PeakEwmaDiscover {
105 discover,
106 decay_ns: nanos(decay),
107 default_rtt,
108 instrument,
109 }
110 }
111}
112
113impl<D, I> Discover for PeakEwmaDiscover<D, I>
114where
115 D: Discover,
116 I: Clone,
117{
118 type Key = D::Key;
119 type Service = PeakEwma<D::Service, I>;
120 type Error = D::Error;
121
122 fn poll_discover(
123 self: Pin<&mut Self>,
124 cx: &mut Context<'_>,
125 ) -> Poll<Result<Change<D::Key, Self::Service>, D::Error>> {
126 let this = self.project();
127 let change = match ready!(this.discover.poll_discover(cx))? {
128 Change::Remove(k) => Change::Remove(k),
129 Change::Insert(k, svc) => {
130 let peak_ewma = PeakEwma::new(
131 svc,
132 *this.default_rtt,
133 *this.decay_ns,
134 this.instrument.clone(),
135 );
136 Change::Insert(k, peak_ewma)
137 }
138 };
139
140 Poll::Ready(Ok(change))
141 }
142}
143
144impl<S, I> PeakEwma<S, I> {
147 fn new(service: S, default_rtt: Duration, decay_ns: f64, instrument: I) -> Self {
148 Self {
149 service,
150 decay_ns,
151 rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
152 instrument,
153 }
154 }
155
156 fn handle(&self) -> Handle {
157 Handle {
158 decay_ns: self.decay_ns,
159 sent_at: Instant::now(),
160 rtt_estimate: self.rtt_estimate.clone(),
161 }
162 }
163}
164
165impl<S, I, Request> Service<Request> for PeakEwma<S, I>
166where
167 S: Service<Request>,
168 I: Instrument<Handle, S::Response>,
169{
170 type Response = I::Output;
171 type Error = S::Error;
172 type Future = InstrumentFuture<S::Future, I, Handle>;
173
174 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175 self.service.poll_ready(cx)
176 }
177
178 fn call(&mut self, req: Request) -> Self::Future {
179 InstrumentFuture::new(
180 self.instrument.clone(),
181 self.handle(),
182 self.service.call(req),
183 )
184 }
185}
186
187impl<S, I> Load for PeakEwma<S, I> {
188 type Metric = Cost;
189
190 fn load(&self) -> Self::Metric {
191 let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
192
193 let estimate = self.update_estimate();
196
197 let cost = Cost(estimate * f64::from(pending + 1));
198 trace!(
199 "load estimate={:.0}ms pending={} cost={:?}",
200 estimate / NANOS_PER_MILLI,
201 pending,
202 cost,
203 );
204 cost
205 }
206}
207
208impl<S, I> PeakEwma<S, I> {
209 fn update_estimate(&self) -> f64 {
210 let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
211 rtt.decay(self.decay_ns)
212 }
213}
214
215impl RttEstimate {
218 fn new(rtt_ns: f64) -> Self {
219 debug_assert!(0.0 < rtt_ns, "rtt must be positive");
220 Self {
221 rtt_ns,
222 update_at: Instant::now(),
223 }
224 }
225
226 fn decay(&mut self, decay_ns: f64) -> f64 {
228 let now = Instant::now();
230 self.update(now, now, decay_ns)
231 }
232
233 fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
237 debug_assert!(
238 sent_at <= recv_at,
239 "recv_at={:?} after sent_at={:?}",
240 recv_at,
241 sent_at
242 );
243 let rtt = nanos(recv_at - sent_at);
244
245 let now = Instant::now();
246 debug_assert!(
247 self.update_at <= now,
248 "update_at={:?} in the future",
249 self.update_at
250 );
251
252 self.rtt_ns = if self.rtt_ns < rtt {
253 trace!(
256 "update peak rtt={}ms prior={}ms",
257 rtt / NANOS_PER_MILLI,
258 self.rtt_ns / NANOS_PER_MILLI,
259 );
260 rtt
261 } else {
262 let elapsed = nanos(now - self.update_at);
267 let decay = (-elapsed / decay_ns).exp();
268 let recency = 1.0 - decay;
269 let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
270 trace!(
271 "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
272 rtt / NANOS_PER_MILLI,
273 self.rtt_ns - next_estimate,
274 next_estimate / NANOS_PER_MILLI,
275 );
276 next_estimate
277 };
278 self.update_at = now;
279
280 self.rtt_ns
281 }
282}
283
284impl Drop for Handle {
287 fn drop(&mut self) {
288 let recv_at = Instant::now();
289
290 if let Ok(mut rtt) = self.rtt_estimate.lock() {
291 rtt.update(self.sent_at, recv_at, self.decay_ns);
292 }
293 }
294}
295
/// Converts a `Duration` into a floating-point number of nanoseconds.
///
/// The whole-second part is converted with saturating arithmetic. If
/// `secs * 1_000_000_000` does not fit in a `u64`, that part clamps to `u64::MAX`
/// before the sub-second nanoseconds are added.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs_ns = d.as_secs().checked_mul(NANOS_PER_SEC).unwrap_or(u64::MAX);
    let fraction_ns = d.subsec_nanos();
    whole_secs_ns as f64 + f64::from(fraction_ns)
}
308
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    /// A minimal service that is always ready and completes every call immediately
    /// with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// With no requests in flight, the load starts at the default RTT (10ms). It then
    /// decays toward zero as the paused clock advances. With a 1s decay constant it is
    /// roughly 10ms * e^-0.1 after 100ms and 10ms * e^-0.2 after 200ms.
    #[tokio::test]
    async fn default_decay() {
        // Pause tokio's clock so elapsed time is driven by the `time::advance` calls.
        time::pause();

        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            NoInstrument,
        );
        // No time has elapsed yet, so the estimate is exactly the default RTT.
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    /// Pending requests multiply the cost, and a completed request slower than the
    /// current estimate becomes the new peak. Once the service is idle, the estimate
    /// decays toward zero.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();

        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            NoInstrument,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // One pending request (its handle lives in `rsp0`): cost = decayed estimate * 2.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        // Two pending requests: cost = decayed estimate * 3.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // rsp0 finishes 200ms after it was sent, which is a new peak. Assuming its handle
        // is released once the response is ready, rsp1 is still pending, so
        // cost = 200ms * 2.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // rsp1 also finishes after 200ms; with nothing pending, cost = 200ms.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // While idle, the estimate keeps decaying toward zero.
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    /// `nanos` is exact for small durations and saturates the whole-second part for
    /// durations too large to count in `u64` nanoseconds.
    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        assert_eq!(
            super::nanos(Duration::new(::std::u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}