tower_limit/rate/
service.rs

1use super::Rate;
2use futures_core::ready;
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use tokio::time::{Delay, Instant};
9use tower_service::Service;
10
11/// Enforces a rate limit on the number of requests the underlying
12/// service can handle over a period of time.
13#[derive(Debug)]
14pub struct RateLimit<T> {
15    inner: T,
16    rate: Rate,
17    state: State,
18}
19
20#[derive(Debug)]
21enum State {
22    // The service has hit its limit
23    Limited(Delay),
24    Ready { until: Instant, rem: u64 },
25}
26
27impl<T> RateLimit<T> {
28    /// Create a new rate limiter
29    pub fn new(inner: T, rate: Rate) -> Self {
30        let state = State::Ready {
31            until: Instant::now(),
32            rem: rate.num(),
33        };
34
35        RateLimit {
36            inner,
37            rate,
38            state: state,
39        }
40    }
41
42    /// Get a reference to the inner service
43    pub fn get_ref(&self) -> &T {
44        &self.inner
45    }
46
47    /// Get a mutable reference to the inner service
48    pub fn get_mut(&mut self) -> &mut T {
49        &mut self.inner
50    }
51
52    /// Consume `self`, returning the inner service
53    pub fn into_inner(self) -> T {
54        self.inner
55    }
56}
57
58impl<S, Request> Service<Request> for RateLimit<S>
59where
60    S: Service<Request>,
61{
62    type Response = S::Response;
63    type Error = S::Error;
64    type Future = S::Future;
65
66    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67        match self.state {
68            State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))),
69            State::Limited(ref mut sleep) => {
70                ready!(Pin::new(sleep).poll(cx));
71            }
72        }
73
74        self.state = State::Ready {
75            until: Instant::now() + self.rate.per(),
76            rem: self.rate.num(),
77        };
78
79        Poll::Ready(ready!(self.inner.poll_ready(cx)))
80    }
81
82    fn call(&mut self, request: Request) -> Self::Future {
83        match self.state {
84            State::Ready { mut until, mut rem } => {
85                let now = Instant::now();
86
87                // If the period has elapsed, reset it.
88                if now >= until {
89                    until = now + self.rate.per();
90                    let rem = self.rate.num();
91
92                    self.state = State::Ready { until, rem }
93                }
94
95                if rem > 1 {
96                    rem -= 1;
97                    self.state = State::Ready { until, rem };
98                } else {
99                    // The service is disabled until further notice
100                    let sleep = tokio::time::delay_until(until);
101                    self.state = State::Limited(sleep);
102                }
103
104                // Call the inner future
105                self.inner.call(request)
106            }
107            State::Limited(..) => panic!("service not ready; poll_ready must be called first"),
108        }
109    }
110}
111
112impl<S> tower_load::Load for RateLimit<S>
113where
114    S: tower_load::Load,
115{
116    type Metric = S::Metric;
117    fn load(&self) -> Self::Metric {
118        self.inner.load()
119    }
120}