reqwest_retry/middleware.rs
1//! `RetryTransientMiddleware` implements retrying requests on transient errors.
2use std::time::{Duration, SystemTime};
3
4use crate::retryable_strategy::RetryableStrategy;
5use crate::{retryable::Retryable, retryable_strategy::DefaultRetryableStrategy, RetryError};
6use anyhow::anyhow;
7use http::Extensions;
8use reqwest::{Request, Response};
9use reqwest_middleware::{Error, Middleware, Next, Result};
10use retry_policies::RetryPolicy;
11
#[doc(hidden)]
// `tracing`'s event macros only accept a level known at compile time
// (see https://github.com/tokio-rs/tracing/issues/2730). This macro takes a
// runtime `tracing::Level` and dispatches to the matching level-specific macro,
// forwarding the remaining tokens (format string and arguments) unchanged.
#[cfg(feature = "tracing")]
macro_rules! log_retry {
    ($lvl:expr, $($fmt:tt)*) => {{
        match $lvl {
            ::tracing::Level::ERROR => ::tracing::error!($($fmt)*),
            ::tracing::Level::WARN => ::tracing::warn!($($fmt)*),
            ::tracing::Level::INFO => ::tracing::info!($($fmt)*),
            ::tracing::Level::DEBUG => ::tracing::debug!($($fmt)*),
            ::tracing::Level::TRACE => ::tracing::trace!($($fmt)*),
        }
    }};
}
27
28/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
29/// and can be safely executed again.
30///
31/// Currently, it allows setting a [RetryPolicy] algorithm for calculating the __wait_time__
32/// between each request retry. Sleeping on non-`wasm32` archs is performed using
33/// [`tokio::time::sleep`], therefore it will respect pauses/auto-advance if run under a
34/// runtime that supports them.
35///
36///```rust
37/// use std::time::Duration;
38/// use reqwest_middleware::ClientBuilder;
39/// use retry_policies::{RetryDecision, RetryPolicy, Jitter};
40/// use retry_policies::policies::ExponentialBackoff;
41/// use reqwest_retry::RetryTransientMiddleware;
42/// use reqwest::Client;
43///
44/// // We create a ExponentialBackoff retry policy which implements `RetryPolicy`.
45/// let retry_policy = ExponentialBackoff::builder()
46/// .retry_bounds(Duration::from_secs(1), Duration::from_secs(60))
47/// .jitter(Jitter::Bounded)
48/// .base(2)
49/// .build_with_total_retry_duration(Duration::from_secs(24 * 60 * 60));
50///
51/// let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
52/// let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build();
53///```
54///
55/// # Note
56///
57/// This middleware always errors when given requests with streaming bodies, before even executing
58/// the request. When this happens you'll get an [`Error::Middleware`] with the message
59/// 'Request object is not cloneable. Are you passing a streaming body?'.
60///
61/// Some workaround suggestions:
62/// * If you can fit the data in memory, you can instead build static request bodies e.g. with
63/// `Body`'s `From<String>` or `From<Bytes>` implementations.
64/// * You can wrap this middleware in a custom one which skips retries for streaming requests.
65/// * You can write a custom retry middleware that builds new streaming requests from the data
66/// source directly, avoiding the issue of streaming requests not being cloneable.
67pub struct RetryTransientMiddleware<
68 T: RetryPolicy + Send + Sync + 'static,
69 R: RetryableStrategy + Send + Sync + 'static = DefaultRetryableStrategy,
70> {
71 retry_policy: T,
72 retryable_strategy: R,
73 #[cfg(feature = "tracing")]
74 retry_log_level: tracing::Level,
75}
76
77impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T, DefaultRetryableStrategy> {
78 /// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy].
79 pub fn new_with_policy(retry_policy: T) -> Self {
80 Self::new_with_policy_and_strategy(retry_policy, DefaultRetryableStrategy)
81 }
82
83 /// Set the log [level][tracing::Level] for retry events.
84 /// The default is [`WARN`][tracing::Level::WARN].
85 #[cfg(feature = "tracing")]
86 pub fn with_retry_log_level(mut self, level: tracing::Level) -> Self {
87 self.retry_log_level = level;
88 self
89 }
90}
91
92impl<T, R> RetryTransientMiddleware<T, R>
93where
94 T: RetryPolicy + Send + Sync,
95 R: RetryableStrategy + Send + Sync,
96{
97 /// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy] and [retryable_strategy](RetryableStrategy).
98 pub fn new_with_policy_and_strategy(retry_policy: T, retryable_strategy: R) -> Self {
99 Self {
100 retry_policy,
101 retryable_strategy,
102 #[cfg(feature = "tracing")]
103 retry_log_level: tracing::Level::WARN,
104 }
105 }
106}
107
108#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
109#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
110impl<T, R> Middleware for RetryTransientMiddleware<T, R>
111where
112 T: RetryPolicy + Send + Sync,
113 R: RetryableStrategy + Send + Sync + 'static,
114{
115 async fn handle(
116 &self,
117 req: Request,
118 extensions: &mut Extensions,
119 next: Next<'_>,
120 ) -> Result<Response> {
121 // TODO: Ideally we should create a new instance of the `Extensions` map to pass
122 // downstream. This will guard against previous retries polluting `Extensions`.
123 // That is, we only return what's populated in the typemap for the last retry attempt
124 // and copy those into the the `global` Extensions map.
125 self.execute_with_retry(req, next, extensions).await
126 }
127}
128
129impl<T, R> RetryTransientMiddleware<T, R>
130where
131 T: RetryPolicy + Send + Sync,
132 R: RetryableStrategy + Send + Sync,
133{
134 /// This function will try to execute the request, if it fails
135 /// with an error classified as transient it will call itself
136 /// to retry the request.
137 async fn execute_with_retry<'a>(
138 &'a self,
139 req: Request,
140 next: Next<'a>,
141 ext: &'a mut Extensions,
142 ) -> Result<Response> {
143 let mut n_past_retries = 0;
144 let start_time = SystemTime::now();
145 loop {
146 // Cloning the request object before-the-fact is not ideal..
147 // However, if the body of the request is not static, e.g of type `Bytes`,
148 // the Clone operation should be of constant complexity and not O(N)
149 // since the byte abstraction is a shared pointer over a buffer.
150 let duplicate_request = req.try_clone().ok_or_else(|| {
151 Error::Middleware(anyhow!(
152 "Request object is not cloneable. Are you passing a streaming body?"
153 .to_string()
154 ))
155 })?;
156
157 let result = next.clone().run(duplicate_request, ext).await;
158
159 // We classify the response which will return None if not
160 // errors were returned.
161 if let Some(Retryable::Transient) = self.retryable_strategy.handle(&result) {
162 // If the response failed and the error type was transient
163 // we can safely try to retry the request.
164 let retry_decision = self.retry_policy.should_retry(start_time, n_past_retries);
165 if let retry_policies::RetryDecision::Retry { execute_after } = retry_decision {
166 let duration = execute_after
167 .duration_since(SystemTime::now())
168 .unwrap_or_else(|_| Duration::default());
169 // Sleep the requested amount before we try again.
170 #[cfg(feature = "tracing")]
171 log_retry!(
172 self.retry_log_level,
173 "Retry attempt #{}. Sleeping {:?} before the next attempt",
174 n_past_retries,
175 duration
176 );
177 #[cfg(not(target_arch = "wasm32"))]
178 tokio::time::sleep(duration).await;
179 #[cfg(target_arch = "wasm32")]
180 wasm_timer::Delay::new(duration)
181 .await
182 .expect("failed sleeping");
183
184 n_past_retries += 1;
185 continue;
186 }
187 };
188
189 // Report whether we failed with or without retries.
190 break if n_past_retries > 0 {
191 result.map_err(|err| {
192 Error::Middleware(
193 RetryError::WithRetries {
194 retries: n_past_retries,
195 err,
196 }
197 .into(),
198 )
199 })
200 } else {
201 result.map_err(|err| Error::Middleware(RetryError::Error(err).into()))
202 };
203 }
204 }
205}