// reqwest_retry/middleware.rs

//! `RetryTransientMiddleware` implements retrying requests on transient errors.
2use std::time::{Duration, SystemTime};
3
4use crate::retryable_strategy::RetryableStrategy;
5use crate::{retryable::Retryable, retryable_strategy::DefaultRetryableStrategy, RetryError};
6use anyhow::anyhow;
7use http::Extensions;
8use reqwest::{Request, Response};
9use reqwest_middleware::{Error, Middleware, Next, Result};
10use retry_policies::RetryPolicy;
11
#[doc(hidden)]
// Emits a `tracing` event at a level that is only known at runtime.
// `tracing`'s event macros require a constant level, so this macro dispatches on the
// runtime value to the matching constant-level macro. See:
// https://github.com/tokio-rs/tracing/issues/2730
#[cfg(feature = "tracing")]
macro_rules! log_retry {
    ($lvl:expr, $($event:tt)*) => {{
        let level: ::tracing::Level = $lvl;
        match level {
            ::tracing::Level::ERROR => ::tracing::error!($($event)*),
            ::tracing::Level::WARN => ::tracing::warn!($($event)*),
            ::tracing::Level::INFO => ::tracing::info!($($event)*),
            ::tracing::Level::DEBUG => ::tracing::debug!($($event)*),
            ::tracing::Level::TRACE => ::tracing::trace!($($event)*),
        }
    }};
}
27
28/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
29/// and can be safely executed again.
30///
31/// Currently, it allows setting a [RetryPolicy] algorithm for calculating the __wait_time__
32/// between each request retry. Sleeping on non-`wasm32` archs is performed using
33/// [`tokio::time::sleep`], therefore it will respect pauses/auto-advance if run under a
34/// runtime that supports them.
35///
36///```rust
37///     use std::time::Duration;
38///     use reqwest_middleware::ClientBuilder;
39///     use retry_policies::{RetryDecision, RetryPolicy, Jitter};
40///     use retry_policies::policies::ExponentialBackoff;
41///     use reqwest_retry::RetryTransientMiddleware;
42///     use reqwest::Client;
43///
44///     // We create a ExponentialBackoff retry policy which implements `RetryPolicy`.
45///     let retry_policy = ExponentialBackoff::builder()
46///         .retry_bounds(Duration::from_secs(1), Duration::from_secs(60))
47///         .jitter(Jitter::Bounded)
48///         .base(2)
49///         .build_with_total_retry_duration(Duration::from_secs(24 * 60 * 60));
50///
51///     let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
52///     let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build();
53///```
54///
55/// # Note
56///
57/// This middleware always errors when given requests with streaming bodies, before even executing
58/// the request. When this happens you'll get an [`Error::Middleware`] with the message
59/// 'Request object is not cloneable. Are you passing a streaming body?'.
60///
61/// Some workaround suggestions:
62/// * If you can fit the data in memory, you can instead build static request bodies e.g. with
63///   `Body`'s `From<String>` or `From<Bytes>` implementations.
64/// * You can wrap this middleware in a custom one which skips retries for streaming requests.
65/// * You can write a custom retry middleware that builds new streaming requests from the data
66///   source directly, avoiding the issue of streaming requests not being cloneable.
67pub struct RetryTransientMiddleware<
68    T: RetryPolicy + Send + Sync + 'static,
69    R: RetryableStrategy + Send + Sync + 'static = DefaultRetryableStrategy,
70> {
71    retry_policy: T,
72    retryable_strategy: R,
73    #[cfg(feature = "tracing")]
74    retry_log_level: tracing::Level,
75}
76
77impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T, DefaultRetryableStrategy> {
78    /// Construct `RetryTransientMiddleware` with  a [retry_policy][RetryPolicy].
79    pub fn new_with_policy(retry_policy: T) -> Self {
80        Self::new_with_policy_and_strategy(retry_policy, DefaultRetryableStrategy)
81    }
82
83    /// Set the log [level][tracing::Level] for retry events.
84    /// The default is [`WARN`][tracing::Level::WARN].
85    #[cfg(feature = "tracing")]
86    pub fn with_retry_log_level(mut self, level: tracing::Level) -> Self {
87        self.retry_log_level = level;
88        self
89    }
90}
91
92impl<T, R> RetryTransientMiddleware<T, R>
93where
94    T: RetryPolicy + Send + Sync,
95    R: RetryableStrategy + Send + Sync,
96{
97    /// Construct `RetryTransientMiddleware` with  a [retry_policy][RetryPolicy] and [retryable_strategy](RetryableStrategy).
98    pub fn new_with_policy_and_strategy(retry_policy: T, retryable_strategy: R) -> Self {
99        Self {
100            retry_policy,
101            retryable_strategy,
102            #[cfg(feature = "tracing")]
103            retry_log_level: tracing::Level::WARN,
104        }
105    }
106}
107
108#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
109#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
110impl<T, R> Middleware for RetryTransientMiddleware<T, R>
111where
112    T: RetryPolicy + Send + Sync,
113    R: RetryableStrategy + Send + Sync + 'static,
114{
115    async fn handle(
116        &self,
117        req: Request,
118        extensions: &mut Extensions,
119        next: Next<'_>,
120    ) -> Result<Response> {
121        // TODO: Ideally we should create a new instance of the `Extensions` map to pass
122        // downstream. This will guard against previous retries polluting `Extensions`.
123        // That is, we only return what's populated in the typemap for the last retry attempt
124        // and copy those into the the `global` Extensions map.
125        self.execute_with_retry(req, next, extensions).await
126    }
127}
128
129impl<T, R> RetryTransientMiddleware<T, R>
130where
131    T: RetryPolicy + Send + Sync,
132    R: RetryableStrategy + Send + Sync,
133{
134    /// This function will try to execute the request, if it fails
135    /// with an error classified as transient it will call itself
136    /// to retry the request.
137    async fn execute_with_retry<'a>(
138        &'a self,
139        req: Request,
140        next: Next<'a>,
141        ext: &'a mut Extensions,
142    ) -> Result<Response> {
143        let mut n_past_retries = 0;
144        let start_time = SystemTime::now();
145        loop {
146            // Cloning the request object before-the-fact is not ideal..
147            // However, if the body of the request is not static, e.g of type `Bytes`,
148            // the Clone operation should be of constant complexity and not O(N)
149            // since the byte abstraction is a shared pointer over a buffer.
150            let duplicate_request = req.try_clone().ok_or_else(|| {
151                Error::Middleware(anyhow!(
152                    "Request object is not cloneable. Are you passing a streaming body?"
153                        .to_string()
154                ))
155            })?;
156
157            let result = next.clone().run(duplicate_request, ext).await;
158
159            // We classify the response which will return None if not
160            // errors were returned.
161            if let Some(Retryable::Transient) = self.retryable_strategy.handle(&result) {
162                // If the response failed and the error type was transient
163                // we can safely try to retry the request.
164                let retry_decision = self.retry_policy.should_retry(start_time, n_past_retries);
165                if let retry_policies::RetryDecision::Retry { execute_after } = retry_decision {
166                    let duration = execute_after
167                        .duration_since(SystemTime::now())
168                        .unwrap_or_else(|_| Duration::default());
169                    // Sleep the requested amount before we try again.
170                    #[cfg(feature = "tracing")]
171                    log_retry!(
172                        self.retry_log_level,
173                        "Retry attempt #{}. Sleeping {:?} before the next attempt",
174                        n_past_retries,
175                        duration
176                    );
177                    #[cfg(not(target_arch = "wasm32"))]
178                    tokio::time::sleep(duration).await;
179                    #[cfg(target_arch = "wasm32")]
180                    wasm_timer::Delay::new(duration)
181                        .await
182                        .expect("failed sleeping");
183
184                    n_past_retries += 1;
185                    continue;
186                }
187            };
188
189            // Report whether we failed with or without retries.
190            break if n_past_retries > 0 {
191                result.map_err(|err| {
192                    Error::Middleware(
193                        RetryError::WithRetries {
194                            retries: n_past_retries,
195                            err,
196                        }
197                        .into(),
198                    )
199                })
200            } else {
201                result.map_err(|err| Error::Middleware(RetryError::Error(err).into()))
202            };
203        }
204    }
205}