tower_limit/rate/
service.rs1use super::Rate;
2use futures_core::ready;
3use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::time::{Delay, Instant};
9use tower_service::Service;
10
11#[derive(Debug)]
14pub struct RateLimit<T> {
15 inner: T,
16 rate: Rate,
17 state: State,
18}
19
20#[derive(Debug)]
21enum State {
22 Limited(Delay),
24 Ready { until: Instant, rem: u64 },
25}
26
27impl<T> RateLimit<T> {
28 pub fn new(inner: T, rate: Rate) -> Self {
30 let state = State::Ready {
31 until: Instant::now(),
32 rem: rate.num(),
33 };
34
35 RateLimit {
36 inner,
37 rate,
38 state: state,
39 }
40 }
41
42 pub fn get_ref(&self) -> &T {
44 &self.inner
45 }
46
47 pub fn get_mut(&mut self) -> &mut T {
49 &mut self.inner
50 }
51
52 pub fn into_inner(self) -> T {
54 self.inner
55 }
56}
57
58impl<S, Request> Service<Request> for RateLimit<S>
59where
60 S: Service<Request>,
61{
62 type Response = S::Response;
63 type Error = S::Error;
64 type Future = S::Future;
65
66 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
67 match self.state {
68 State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))),
69 State::Limited(ref mut sleep) => {
70 ready!(Pin::new(sleep).poll(cx));
71 }
72 }
73
74 self.state = State::Ready {
75 until: Instant::now() + self.rate.per(),
76 rem: self.rate.num(),
77 };
78
79 Poll::Ready(ready!(self.inner.poll_ready(cx)))
80 }
81
82 fn call(&mut self, request: Request) -> Self::Future {
83 match self.state {
84 State::Ready { mut until, mut rem } => {
85 let now = Instant::now();
86
87 if now >= until {
89 until = now + self.rate.per();
90 let rem = self.rate.num();
91
92 self.state = State::Ready { until, rem }
93 }
94
95 if rem > 1 {
96 rem -= 1;
97 self.state = State::Ready { until, rem };
98 } else {
99 let sleep = tokio::time::delay_until(until);
101 self.state = State::Limited(sleep);
102 }
103
104 self.inner.call(request)
106 }
107 State::Limited(..) => panic!("service not ready; poll_ready must be called first"),
108 }
109 }
110}
111
112impl<S> tower_load::Load for RateLimit<S>
113where
114 S: tower_load::Load,
115{
116 type Metric = S::Metric;
117 fn load(&self) -> Self::Metric {
118 self.inner.load()
119 }
120}