async_rate_limit/token_bucket.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
use crate::limiters::VariableCostRateLimiter;
use std::sync::Arc;
use tokio::sync::{Mutex, OwnedSemaphorePermit, Semaphore};
use tokio::task::JoinHandle;
use tokio::time::Duration;
/// A classic [token bucket](https://en.wikipedia.org/wiki/Token_bucket) rate limiter that treats non-conformant calls by waiting indefinitely
/// for tokens to become available.
///
/// The token bucket scheme lets callers "burst" up to the bucket's maximum number of tokens at
/// once, while used tokens are replaced at a fixed rate. Callers keep some flexibility, and the
/// overall load on the server stays bounded.
///
/// The behavior of its implementation of [`VariableCostRateLimiter`] can be seen in this
/// [`example`].
///
/// `TokenBucketRateLimiter` implements [`Clone`]. Cloning a rate limiter clones only the `Arc`
/// handle to its [`TokenBucketState`], so the original and the clone share the same bucket.
/// This allows structs containing a rate limiter, like an HTTP or REST client, to be [`Clone`]
/// while every instance still respects the same rate limit.
///
/// *Trying to acquire more tokens in a single call than the bucket can ever hold will never
/// complete (the call deadlocks).*
///
/// [`example`]: struct.TokenBucketRateLimiter.html#example-a-variable-cost-api-rate-limit
#[derive(Debug, Clone)]
pub struct TokenBucketRateLimiter {
    /// Potentially shared, mutable state required to implement the token bucket scheme.
    state: Arc<Mutex<TokenBucketState>>,
}
/// All required state that can be shared among many [`TokenBucketRateLimiters`](`TokenBucketRateLimiter`).
///
/// [`TokenBucketRateLimiter`]s take `Arc<Mutex<TokenBucketState>>` so a single state (including the replenishment task)
/// can be shared among many rate limiters, e.g. when a single API has multiple endpoints, each requiring different costs
/// but counting against the same user rate limit.
///
/// The background replenishment task is spawned lazily on the first token acquisition and is
/// aborted when this state is dropped.
#[derive(Debug)]
pub struct TokenBucketState {
    /// Starting and max tokens in the bucket; each outstanding permit is one token taken out of
    /// the bucket. This should NOT be shared outside of [`TokenBucketState`](`TokenBucketState`).
    tokens: Arc<Semaphore>,
    /// Number of tokens to be replaced every `replace_duration`
    replace_amount: usize,
    /// Duration after which `replace_amount` tokens are added to the bucket (unless it's full)
    replace_duration: Duration,
    /// Permits taken from `tokens` by completed acquisitions, in acquisition order. On each tick
    /// the replenishment task drops up to `replace_amount` of the oldest ones, which returns them
    /// to the bucket.
    acquired_tokens: Arc<Mutex<Vec<OwnedSemaphorePermit>>>,
    /// Handle to task that ticks on an interval, replacing `replace_amount` tokens every `replace_duration`.
    /// `None` until the first acquisition spawns it.
    replenish_task: Option<JoinHandle<()>>,
}
impl VariableCostRateLimiter for TokenBucketRateLimiter {
    /// Wait until `cost` tokens can be taken from the (possibly shared) bucket, then take them.
    ///
    /// Callers sharing one [`TokenBucketState`] are served one at a time. A call whose `cost`
    /// exceeds the bucket's maximum number of tokens never completes.
    ///
    /// # Example: A Variable Cost API Rate Limit
    ///
    /// An API enforces a rate limit by allotting 10 tokens, and replenishes used tokens at a rate of 1 per second.
    /// An endpoint being called requires 4 tokens per call.
    ///
    /// ```
    /// use tokio::time::{Instant, Duration};
    /// use async_rate_limit::limiters::VariableCostRateLimiter;
    /// use async_rate_limit::token_bucket::{TokenBucketState, TokenBucketRateLimiter};
    /// use std::sync::Arc;
    /// use tokio::sync::Mutex;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let state = TokenBucketState::new(10, 1, Duration::from_secs(1));
    ///     let state_mutex = Arc::new(Mutex::new(state));
    ///     let mut limiter = TokenBucketRateLimiter::new(state_mutex);
    ///
    ///     // these calls proceed immediately, using 8 tokens
    ///     get_something(&mut limiter).await;
    ///     get_something(&mut limiter).await;
    ///
    ///     // this call waits ~2 seconds to acquire additional tokens before proceeding
    ///     get_something(&mut limiter).await;
    /// }
    ///
    /// // note the use of the `VariableCostRateLimiter` trait, rather than the direct type
    /// async fn get_something<T>(limiter: &mut T) where T: VariableCostRateLimiter {
    ///     limiter.wait_with_cost(4).await;
    ///     println!("{:?}", Instant::now());
    /// }
    /// ```
    ///
    async fn wait_with_cost(&mut self, cost: usize) {
        // Keep the state locked for the whole acquisition. Permits are taken one at a time, so
        // two interleaved callers could each hold part of the bucket and wait on each other
        // forever (partially acquired permits are not eligible for replenishment).
        let mut state = self.state.lock().await;
        state.acquire_tokens(cost).await;
    }
}
impl Drop for TokenBucketState {
    /// Abort the background replenishment task, if one was ever spawned.
    fn drop(&mut self) {
        // Dropping a `JoinHandle` merely detaches its task, and the replenishment loop never
        // exits on its own, so it has to be aborted explicitly.
        let Some(replenisher) = self.replenish_task.take() else {
            return;
        };
        replenisher.abort();
    }
}
impl TokenBucketState {
/// Create a new [`TokenBucketState`] with a full bucket of `max_tokens` that will be replenished
/// with `replace_amount` tokens every `replace_duration`.
pub fn new(max_tokens: usize, replace_amount: usize, replace_duration: Duration) -> Self {
TokenBucketState {
tokens: Arc::new(Semaphore::new(max_tokens)),
replace_amount,
replace_duration,
acquired_tokens: Arc::new(Mutex::new(vec![])),
replenish_task: None,
}
}
fn create_task_if_none(&mut self) {
if self.replenish_task.is_none() {
let acquired_tokens = self.acquired_tokens.clone();
let replace_amount = self.replace_amount;
let replace_duration = self.replace_duration;
let handle = tokio::spawn(async move {
Self::replenish_on_schedule(acquired_tokens, replace_amount, replace_duration)
.await;
});
self.replenish_task = Some(handle);
}
}
async fn acquire_tokens(&mut self, n_tokens: usize) {
self.create_task_if_none();
let mut tokens = Vec::with_capacity(n_tokens);
for _ in 0..n_tokens {
let token = self
.tokens
.clone()
.acquire_owned()
.await
.expect("Failed to acquire tokens.");
tokens.push(token);
}
let mut acquired_tokens_guard = self.acquired_tokens.lock().await;
(*acquired_tokens_guard).extend(tokens);
}
async fn replenish_on_schedule(
acquired_tokens: Arc<Mutex<Vec<OwnedSemaphorePermit>>>,
replace_amount: usize,
replace_duration: Duration,
) {
let mut interval = tokio::time::interval(replace_duration);
// tick once to avoid instant replenishment on startup
interval.tick().await;
loop {
interval.tick().await;
TokenBucketState::release_tokens(replace_amount, acquired_tokens.clone()).await;
}
}
async fn release_tokens(
replace_amount: usize,
acquired_tokens: Arc<Mutex<Vec<OwnedSemaphorePermit>>>,
) {
let mut acquired_tokens_guard = acquired_tokens.lock().await;
let release_amount = replace_amount.min(acquired_tokens_guard.len());
let owned_tokens = (*acquired_tokens_guard).drain(0..release_amount);
for token in owned_tokens.into_iter() {
drop(token);
}
}
}
impl TokenBucketRateLimiter {
/// Create a new [`TokenBucketRateLimiter`] using an established [`TokenBucketState`].
///
/// `token_bucket_state` can be a reference for just this rate limiter, or it can be shared across
/// many different rate limiters.
pub fn new(token_bucket_state: Arc<Mutex<TokenBucketState>>) -> Self {
TokenBucketRateLimiter {
state: token_bucket_state,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{pause, sleep, Instant};

    #[tokio::test]
    async fn test_proceeds_immediately_under_limit() {
        pause();
        let state = TokenBucketState::new(10, 3, Duration::from_secs(3));
        let state_mutex = Arc::new(Mutex::new(state));
        let mut limiter = TokenBucketRateLimiter::new(state_mutex);
        let start = Instant::now();
        limiter.wait_with_cost(2).await;
        limiter.wait_with_cost(2).await;
        limiter.wait_with_cost(2).await;
        limiter.wait_with_cost(2).await;
        limiter.wait_with_cost(2).await;
        let end = Instant::now();
        let elapsed = end - start;
        assert!(elapsed < Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_multi_cost_waits_at_limit() {
        pause();
        let state = TokenBucketState::new(10, 2, Duration::from_secs(3));
        let state_mutex = Arc::new(Mutex::new(state));
        let mut limiter = TokenBucketRateLimiter::new(state_mutex);
        let start = Instant::now();
        limiter.wait_with_cost(8).await;
        limiter.wait_with_cost(3).await;
        let end = Instant::now();
        let elapsed = end - start;
        assert!(elapsed > Duration::from_secs(3));
        assert!(elapsed < Duration::from_secs(4));
    }

    #[tokio::test]
    async fn test_cloned_instances_share_underlying_state() {
        pause();
        let state = TokenBucketState::new(10, 2, Duration::from_secs(5));
        let state_mutex = Arc::new(Mutex::new(state));
        let mut limiter = TokenBucketRateLimiter::new(state_mutex);
        let start = Instant::now();
        limiter.wait_with_cost(8).await;
        // clone should share same state and have to wait
        limiter.clone().wait_with_cost(3).await;
        let end = Instant::now();
        let elapsed = end - start;
        assert!(elapsed > Duration::from_secs(5));
        assert!(elapsed < Duration::from_secs(6));
    }

    #[tokio::test]
    async fn test_bucket_does_not_overflow_over_time() {
        pause();
        let state = TokenBucketState::new(10, 2, Duration::from_secs(3));
        let state_mutex = Arc::new(Mutex::new(state));
        let mut limiter = TokenBucketRateLimiter::new(state_mutex);
        // bucket should not accumulate more than max tokens when not in use
        sleep(Duration::from_secs(180)).await;
        let start = Instant::now();
        limiter.wait_with_cost(8).await;
        limiter.wait_with_cost(3).await;
        let end = Instant::now();
        let elapsed = end - start;
        assert!(elapsed > Duration::from_secs(3));
        assert!(elapsed < Duration::from_secs(4));
    }

    #[tokio::test]
    async fn test_bucket_does_not_replace_over() {
        pause();
        let state = TokenBucketState::new(10, 100, Duration::from_secs(3));
        let state_mutex = Arc::new(Mutex::new(state));
        let mut limiter = TokenBucketRateLimiter::new(state_mutex);
        // bucket should not accumulate more than max tokens when not in use, and
        // replacing a large amount should not over-drain the permits Vec and panic
        sleep(Duration::from_secs(180)).await;
        let start = Instant::now();
        limiter.wait_with_cost(8).await;
        limiter.wait_with_cost(3).await;
        let end = Instant::now();
        let elapsed = end - start;
        assert!(elapsed > Duration::from_secs(3));
        assert!(elapsed < Duration::from_secs(4));
    }

    #[tokio::test]
    async fn test_many_waiters() {
        pause();
        let start = Instant::now();
        let mut tasks = vec![];
        let state = TokenBucketState::new(10, 2, Duration::from_secs(3));
        let state_mutex = Arc::new(Mutex::new(state));
        for _ in 0..10 {
            let task_mutex = state_mutex.clone();
            let task = tokio::spawn(async move {
                let mut limiter = TokenBucketRateLimiter::new(task_mutex);
                limiter.wait_with_cost(5).await;
            });
            tasks.push(task);
        }
        for task in tasks {
            // Surface panics from the spawned limiter calls instead of silently ignoring them.
            task.await.expect("rate-limited task panicked");
        }
        let end = Instant::now();
        let duration = end - start;
        // 10 tasks * 5 cost per = 50 tokens
        // 40 needed after initial 10 tokens spent
        // 40 needed / 2 replace = 20 waits of 3s = 60s
        assert!(duration > Duration::from_secs(60));
        assert!(duration < Duration::from_secs(61));
    }
}