tower_load/
peak_ewma.rs

//! A `Load` implementation that uses Peak-EWMA on response latency.
2
3use super::{Instrument, InstrumentFuture, NoInstrument};
4use crate::Load;
5use futures_core::ready;
6use log::trace;
7use pin_project::pin_project;
8use std::{
9    pin::Pin,
10    task::{Context, Poll},
11};
12use std::{
13    sync::{Arc, Mutex},
14    time::Duration,
15};
16use tokio::time::Instant;
17use tower_discover::{Change, Discover};
18use tower_service::Service;
19
/// Wraps an `S`-typed Service with Peak-EWMA load measurement.
///
/// `PeakEwma` implements `Load` with the `Cost` metric that estimates the amount of
/// pending work to an endpoint. Work is calculated by multiplying the
/// exponentially-weighted moving average (EWMA) of response latencies by the number of
/// pending requests. The Peak-EWMA algorithm is designed to be especially sensitive to
/// worst-case latencies. Over time, the peak latency value decays towards the moving
/// average of latencies to the endpoint.
///
/// As requests are sent to the underlying service, an `I`-typed instrumentation strategy
/// is used to track responses to measure latency in an application-specific way. The
/// default strategy measures latency as the elapsed time from the request being issued to
/// the underlying service to the response future being satisfied (or dropped).
///
/// When no latency information has been measured for an endpoint, the default RTT the
/// service was constructed with is used as the estimate, to prevent the endpoint from
/// being overloaded before a meaningful baseline can be established.
///
/// ## Note
///
/// This is derived from [Finagle][finagle], which is distributed under the Apache V2
/// license. Copyright 2017, Twitter Inc.
///
/// [finagle]:
/// https://github.com/twitter/finagle/blob/9cc08d15216497bb03a1cafda96b7266cfbbcff1/finagle-core/src/main/scala/com/twitter/finagle/loadbalancer/PeakEwma.scala
#[derive(Debug)]
pub struct PeakEwma<S, I = NoInstrument> {
    // The wrapped service that requests are forwarded to.
    service: S,
    // Decay period for the RTT estimate, in nanoseconds.
    decay_ns: f64,
    // The current RTT estimate; a clone of this `Arc` is handed to every `Handle`.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
    // Instrumentation strategy cloned into each response future.
    instrument: I,
}
52
/// Wraps a `D`-typed stream of discovery updates with `PeakEwma`.
///
/// Every service inserted by the inner discovery stream is wrapped in a `PeakEwma`
/// configured with the values stored here.
#[pin_project]
#[derive(Debug)]
pub struct PeakEwmaDiscover<D, I = NoInstrument> {
    // The inner discovery stream; pinned because it is polled through `Pin<&mut Self>`.
    #[pin]
    discover: D,
    // Decay period handed to each wrapped service, in nanoseconds.
    decay_ns: f64,
    // Initial RTT estimate handed to each wrapped service.
    default_rtt: Duration,
    // Instrumentation strategy cloned into each wrapped service.
    instrument: I,
}
63
/// Represents the relative cost of communicating with a service.
///
/// The underlying value estimates the amount of pending work to a service: the Peak-EWMA
/// latency estimate multiplied by the number of pending requests.
///
/// The wrapped value is derived from an estimate held in nanoseconds. Costs are only
/// partially ordered (`PartialOrd`) because the inner value is an `f64`.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
70
/// Tracks an in-flight request and updates the RTT-estimate on Drop.
///
/// A `Handle` is created when a request is dispatched and shares the service's RTT
/// estimate, so the number of live handles is visible through the `Arc` reference count.
#[derive(Debug)]
pub struct Handle {
    // When the request was issued.
    sent_at: Instant,
    // Decay period to apply when the estimate is updated, in nanoseconds.
    decay_ns: f64,
    // The RTT estimate shared with the `PeakEwma` that created this handle.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
}
78
/// Holds the current RTT estimate and the last time this value was updated.
#[derive(Debug)]
struct RttEstimate {
    // The time at which `rtt_ns` was last written.
    update_at: Instant,
    // The current RTT estimate, in nanoseconds.
    rtt_ns: f64,
}
85
/// The number of nanoseconds in one millisecond; used to print nanosecond estimates as
/// milliseconds in trace output.
const NANOS_PER_MILLI: f64 = 1_000_000.0;
87
// ===== impl PeakEwmaDiscover =====
89
90impl<D, I> PeakEwmaDiscover<D, I> {
91    /// Wraps a `D`-typed `Discover` so that services have a `PeakEwma` load metric.
92    ///
93    /// The provided `default_rtt` is used as the default RTT estimate for newly
94    /// added services.
95    ///
96    /// They `decay` value determines over what time period a RTT estimate should
97    /// decay.
98    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, instrument: I) -> Self
99    where
100        D: Discover,
101        D::Service: Service<Request>,
102        I: Instrument<Handle, <D::Service as Service<Request>>::Response>,
103    {
104        PeakEwmaDiscover {
105            discover,
106            decay_ns: nanos(decay),
107            default_rtt,
108            instrument,
109        }
110    }
111}
112
113impl<D, I> Discover for PeakEwmaDiscover<D, I>
114where
115    D: Discover,
116    I: Clone,
117{
118    type Key = D::Key;
119    type Service = PeakEwma<D::Service, I>;
120    type Error = D::Error;
121
122    fn poll_discover(
123        self: Pin<&mut Self>,
124        cx: &mut Context<'_>,
125    ) -> Poll<Result<Change<D::Key, Self::Service>, D::Error>> {
126        let this = self.project();
127        let change = match ready!(this.discover.poll_discover(cx))? {
128            Change::Remove(k) => Change::Remove(k),
129            Change::Insert(k, svc) => {
130                let peak_ewma = PeakEwma::new(
131                    svc,
132                    *this.default_rtt,
133                    *this.decay_ns,
134                    this.instrument.clone(),
135                );
136                Change::Insert(k, peak_ewma)
137            }
138        };
139
140        Poll::Ready(Ok(change))
141    }
142}
143
144// ===== impl PeakEwma =====
145
146impl<S, I> PeakEwma<S, I> {
147    fn new(service: S, default_rtt: Duration, decay_ns: f64, instrument: I) -> Self {
148        Self {
149            service,
150            decay_ns,
151            rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
152            instrument,
153        }
154    }
155
156    fn handle(&self) -> Handle {
157        Handle {
158            decay_ns: self.decay_ns,
159            sent_at: Instant::now(),
160            rtt_estimate: self.rtt_estimate.clone(),
161        }
162    }
163}
164
165impl<S, I, Request> Service<Request> for PeakEwma<S, I>
166where
167    S: Service<Request>,
168    I: Instrument<Handle, S::Response>,
169{
170    type Response = I::Output;
171    type Error = S::Error;
172    type Future = InstrumentFuture<S::Future, I, Handle>;
173
174    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
175        self.service.poll_ready(cx)
176    }
177
178    fn call(&mut self, req: Request) -> Self::Future {
179        InstrumentFuture::new(
180            self.instrument.clone(),
181            self.handle(),
182            self.service.call(req),
183        )
184    }
185}
186
187impl<S, I> Load for PeakEwma<S, I> {
188    type Metric = Cost;
189
190    fn load(&self) -> Self::Metric {
191        let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
192
193        // Update the RTT estimate to account for decay since the last update.
194        // If an estimate has not been established, a default is provided
195        let estimate = self.update_estimate();
196
197        let cost = Cost(estimate * f64::from(pending + 1));
198        trace!(
199            "load estimate={:.0}ms pending={} cost={:?}",
200            estimate / NANOS_PER_MILLI,
201            pending,
202            cost,
203        );
204        cost
205    }
206}
207
208impl<S, I> PeakEwma<S, I> {
209    fn update_estimate(&self) -> f64 {
210        let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
211        rtt.decay(self.decay_ns)
212    }
213}
214
215// ===== impl RttEstimate =====
216
217impl RttEstimate {
218    fn new(rtt_ns: f64) -> Self {
219        debug_assert!(0.0 < rtt_ns, "rtt must be positive");
220        Self {
221            rtt_ns,
222            update_at: Instant::now(),
223        }
224    }
225
226    /// Decays the RTT estimate with a decay period of `decay_ns`.
227    fn decay(&mut self, decay_ns: f64) -> f64 {
228        // Updates with a 0 duration so that the estimate decays towards 0.
229        let now = Instant::now();
230        self.update(now, now, decay_ns)
231    }
232
233    /// Updates the Peak-EWMA RTT estimate.
234    ///
235    /// The elapsed time from `sent_at` to `recv_at` is added
236    fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
237        debug_assert!(
238            sent_at <= recv_at,
239            "recv_at={:?} after sent_at={:?}",
240            recv_at,
241            sent_at
242        );
243        let rtt = nanos(recv_at - sent_at);
244
245        let now = Instant::now();
246        debug_assert!(
247            self.update_at <= now,
248            "update_at={:?} in the future",
249            self.update_at
250        );
251
252        self.rtt_ns = if self.rtt_ns < rtt {
253            // For Peak-EWMA, always use the worst-case (peak) value as the estimate for
254            // subsequent requests.
255            trace!(
256                "update peak rtt={}ms prior={}ms",
257                rtt / NANOS_PER_MILLI,
258                self.rtt_ns / NANOS_PER_MILLI,
259            );
260            rtt
261        } else {
262            // When an RTT is observed that is less than the estimated RTT, we decay the
263            // prior estimate according to how much time has elapsed since the last
264            // update. The inverse of the decay is used to scale the estimate towards the
265            // observed RTT value.
266            let elapsed = nanos(now - self.update_at);
267            let decay = (-elapsed / decay_ns).exp();
268            let recency = 1.0 - decay;
269            let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
270            trace!(
271                "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
272                rtt / NANOS_PER_MILLI,
273                self.rtt_ns - next_estimate,
274                next_estimate / NANOS_PER_MILLI,
275            );
276            next_estimate
277        };
278        self.update_at = now;
279
280        self.rtt_ns
281    }
282}
283
284// ===== impl Handle =====
285
286impl Drop for Handle {
287    fn drop(&mut self) {
288        let recv_at = Instant::now();
289
290        if let Ok(mut rtt) = self.rtt_estimate.lock() {
291            rtt.update(self.sent_at, recv_at, self.decay_ns);
292        }
293    }
294}
295
// ===== utilities =====
297
// Converts a `Duration` into a count of nanoseconds expressed as an `f64`.
//
// The whole-second part is scaled with a saturating `u64` multiplication, so it tops
// out at `u64::MAX` nanoseconds (~585 years), which should be far more than any request
// latency. The conversion to `f64` is lossy for very large values.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs_ns = d.as_secs().saturating_mul(NANOS_PER_SEC) as f64;
    let subsec_ns = f64::from(d.subsec_nanos());
    subsec_ns + whole_secs_ns
}
308
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    /// A trivial service that is always ready and responds immediately with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// The default RTT estimate decays, so that new nodes are considered if the
    /// default RTT is too high.
    #[tokio::test]
    async fn default_decay() {
        // Time is paused so that it only moves when explicitly advanced below.
        time::pause();

        // 10ms default RTT with a 1s decay period (expressed in nanoseconds).
        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            NoInstrument,
        );
        // With no time elapsed and nothing pending, the load is the default RTT.
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    /// Exercises the load estimate while requests are pending and after their
    /// responses are polled to completion, then checks that it decays as time elapses.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();

        // 20ms default RTT with a 1s decay period (expressed in nanoseconds).
        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            NoInstrument,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // Issue a first request 100ms in; it is not polled yet.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        // Issue a second request another 100ms later.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // NOTE(review): presumably polling a response to completion releases its
        // `Handle`, which records the elapsed time as an RTT — confirm against
        // `InstrumentFuture` / `NoInstrument`.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // Check that values decay as time elapses
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    /// Checks the `Duration` to `f64` nanosecond conversion, including the largest
    /// representable `Duration`.
    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        assert_eq!(
            super::nanos(Duration::new(::std::u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}