tower_limit/concurrency/service.rs

1use super::future::ResponseFuture;
2
3use tower_service::Service;
4
5use super::sync::semaphore::{self, Semaphore};
6use futures_core::ready;
7use std::sync::Arc;
8use std::task::{Context, Poll};
9
10/// Enforces a limit on the concurrent number of requests the underlying
11/// service can handle.
12#[derive(Debug)]
13pub struct ConcurrencyLimit<T> {
14    inner: T,
15    limit: Limit,
16}
17
18#[derive(Debug)]
19struct Limit {
20    semaphore: Arc<Semaphore>,
21    permit: semaphore::Permit,
22}
23
24impl<T> ConcurrencyLimit<T> {
25    /// Create a new concurrency limiter.
26    pub fn new(inner: T, max: usize) -> Self {
27        ConcurrencyLimit {
28            inner,
29            limit: Limit {
30                semaphore: Arc::new(Semaphore::new(max)),
31                permit: semaphore::Permit::new(),
32            },
33        }
34    }
35
36    /// Get a reference to the inner service
37    pub fn get_ref(&self) -> &T {
38        &self.inner
39    }
40
41    /// Get a mutable reference to the inner service
42    pub fn get_mut(&mut self) -> &mut T {
43        &mut self.inner
44    }
45
46    /// Consume `self`, returning the inner service
47    pub fn into_inner(self) -> T {
48        self.inner
49    }
50}
51
52impl<S, Request> Service<Request> for ConcurrencyLimit<S>
53where
54    S: Service<Request>,
55{
56    type Response = S::Response;
57    type Error = S::Error;
58    type Future = ResponseFuture<S::Future>;
59
60    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61        ready!(self.limit.permit.poll_acquire(cx, &self.limit.semaphore))
62            .expect("poll_acquire after semaphore closed ");
63
64        Poll::Ready(ready!(self.inner.poll_ready(cx)))
65    }
66
67    fn call(&mut self, request: Request) -> Self::Future {
68        // Make sure a permit has been acquired
69        if self
70            .limit
71            .permit
72            .try_acquire(&self.limit.semaphore)
73            .is_err()
74        {
75            panic!("max requests in-flight; poll_ready must be called first");
76        }
77
78        // Call the inner service
79        let future = self.inner.call(request);
80
81        // Forget the permit, the permit will be returned when
82        // `future::ResponseFuture` is dropped.
83        self.limit.permit.forget();
84
85        ResponseFuture::new(future, self.limit.semaphore.clone())
86    }
87}
88
89impl<S> tower_load::Load for ConcurrencyLimit<S>
90where
91    S: tower_load::Load,
92{
93    type Metric = S::Metric;
94    fn load(&self) -> Self::Metric {
95        self.inner.load()
96    }
97}
98
99impl<S> Clone for ConcurrencyLimit<S>
100where
101    S: Clone,
102{
103    fn clone(&self) -> ConcurrencyLimit<S> {
104        ConcurrencyLimit {
105            inner: self.inner.clone(),
106            limit: Limit {
107                semaphore: self.limit.semaphore.clone(),
108                permit: semaphore::Permit::new(),
109            },
110        }
111    }
112}
113
114impl Drop for Limit {
115    fn drop(&mut self) {
116        self.permit.release(&self.semaphore);
117    }
118}