reqwest_retry/
retryable_strategy.rs

1use crate::retryable::Retryable;
2use http::StatusCode;
3use reqwest_middleware::Error;
4
5/// A strategy to create a [`Retryable`] from a [`Result<reqwest::Response, reqwest_middleware::Error>`]
6///
7/// A [`RetryableStrategy`] has a single `handler` functions.
8/// The result of calling the request could be:
9/// - [`reqwest::Response`] In case the request has been sent and received correctly
10///     This could however still mean that the server responded with a erroneous response.
11///     For example a HTTP statuscode of 500
12/// - [`reqwest_middleware::Error`] In this case the request actually failed.
13///     This could, for example, be caused by a timeout on the connection.
14///
15/// Example:
16///
17/// ```
18/// use reqwest_retry::{default_on_request_failure, policies::ExponentialBackoff, Retryable, RetryableStrategy, RetryTransientMiddleware};
19/// use reqwest::{Request, Response};
20/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
21/// use http::Extensions;
22///
23/// // Log each request to show that the requests will be retried
24/// struct LoggingMiddleware;
25///
26/// #[async_trait::async_trait]
27/// impl Middleware for LoggingMiddleware {
28///     async fn handle(
29///         &self,
30///         req: Request,
31///         extensions: &mut Extensions,
32///         next: Next<'_>,
33///     ) -> Result<Response> {
34///         println!("Request started {}", req.url());
35///         let res = next.run(req, extensions).await;
36///         println!("Request finished");
37///         res
38///     }
39/// }
40///
41/// // Just a toy example, retry when the successful response code is 201, else do nothing.
42/// struct Retry201;
43/// impl RetryableStrategy for Retry201 {
44///     fn handle(&self, res: &Result<reqwest::Response>) -> Option<Retryable> {
45///          match res {
46///              // retry if 201
47///              Ok(success) if success.status() == 201 => Some(Retryable::Transient),
48///              // otherwise do not retry a successful request
49///              Ok(success) => None,
50///              // but maybe retry a request failure
51///              Err(error) => default_on_request_failure(error),
52///         }
53///     }
54/// }
55///
56/// #[tokio::main]
57/// async fn main() {
58///     // Exponential backoff with max 2 retries
59///     let retry_policy = ExponentialBackoff::builder()
60///         .build_with_max_retries(2);
61///     
62///     // Create the actual middleware, with the exponential backoff and custom retry strategy.
63///     let ret_s = RetryTransientMiddleware::new_with_policy_and_strategy(
64///         retry_policy,
65///         Retry201,
66///     );
67///
68///     let client = ClientBuilder::new(reqwest::Client::new())
69///         // Retry failed requests.
70///         .with(ret_s)
71///         // Log the requests
72///         .with(LoggingMiddleware)
73///         .build();
74///
75///     // Send request which should get a 201 response. So it will be retried
76///     let r = client   
77///         .get("https://httpbin.org/status/201")
78///         .send()
79///         .await;
80///     println!("{:?}", r);
81///
82///     // Send request which should get a 200 response. So it will not be retried
83///     let r = client   
84///         .get("https://httpbin.org/status/200")
85///         .send()
86///         .await;
87///     println!("{:?}", r);
88/// }
89/// ```
90pub trait RetryableStrategy {
91    fn handle(&self, res: &Result<reqwest::Response, Error>) -> Option<Retryable>;
92}
93
94/// The default [`RetryableStrategy`] for [`RetryTransientMiddleware`](crate::RetryTransientMiddleware).
95pub struct DefaultRetryableStrategy;
96
97impl RetryableStrategy for DefaultRetryableStrategy {
98    fn handle(&self, res: &Result<reqwest::Response, Error>) -> Option<Retryable> {
99        match res {
100            Ok(success) => default_on_request_success(success),
101            Err(error) => default_on_request_failure(error),
102        }
103    }
104}
105
106/// Default request success retry strategy.
107///
108/// Will only retry if:
109/// * The status was 5XX (server error)
110/// * The status was 408 (request timeout) or 429 (too many requests)
111///
112/// Note that success here means that the request finished without interruption, not that it was logically OK.
113pub fn default_on_request_success(success: &reqwest::Response) -> Option<Retryable> {
114    let status = success.status();
115    if status.is_server_error() {
116        Some(Retryable::Transient)
117    } else if status.is_client_error()
118        && status != StatusCode::REQUEST_TIMEOUT
119        && status != StatusCode::TOO_MANY_REQUESTS
120    {
121        Some(Retryable::Fatal)
122    } else if status.is_success() {
123        None
124    } else if status == StatusCode::REQUEST_TIMEOUT || status == StatusCode::TOO_MANY_REQUESTS {
125        Some(Retryable::Transient)
126    } else {
127        Some(Retryable::Fatal)
128    }
129}
130
131/// Default request failure retry strategy.
132///
133/// Will only retry if the request failed due to a network error
134pub fn default_on_request_failure(error: &Error) -> Option<Retryable> {
135    match error {
136        // If something fails in the middleware we're screwed.
137        Error::Middleware(_) => Some(Retryable::Fatal),
138        Error::Reqwest(error) => {
139            #[cfg(not(target_arch = "wasm32"))]
140            let is_connect = error.is_connect();
141            #[cfg(target_arch = "wasm32")]
142            let is_connect = false;
143            if error.is_timeout() || is_connect {
144                Some(Retryable::Transient)
145            } else if error.is_body()
146                || error.is_decode()
147                || error.is_builder()
148                || error.is_redirect()
149            {
150                Some(Retryable::Fatal)
151            } else if error.is_request() {
152                // It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest.
153                // Here we check if the Reqwest error was originated by hyper and map it consistently.
154                #[cfg(not(target_arch = "wasm32"))]
155                if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
156                    // The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
157                    // This can happen when the server has started sending back the response but the connection is cut halfway through.
158                    // We can safely retry the call, hence marking this error as [`Retryable::Transient`].
159                    // Instead hyper::Error(Canceled) is raised when the connection is
160                    // gracefully closed on the server side.
161                    if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
162                        Some(Retryable::Transient)
163
164                    // Try and downcast the hyper error to io::Error if that is the
165                    // underlying error, and try and classify it.
166                    } else if let Some(io_error) =
167                        get_source_error_type::<std::io::Error>(hyper_error)
168                    {
169                        Some(classify_io_error(io_error))
170                    } else {
171                        Some(Retryable::Fatal)
172                    }
173                } else {
174                    Some(Retryable::Fatal)
175                }
176                #[cfg(target_arch = "wasm32")]
177                Some(Retryable::Fatal)
178            } else {
179                // We omit checking if error.is_status() since we check that already.
180                // However, if Response::error_for_status is used the status will still
181                // remain in the response object.
182                None
183            }
184        }
185    }
186}
187
188#[cfg(not(target_arch = "wasm32"))]
189fn classify_io_error(error: &std::io::Error) -> Retryable {
190    match error.kind() {
191        std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => {
192            Retryable::Transient
193        }
194        _ => Retryable::Fatal,
195    }
196}
197
/// Walks the `source()` chain of `err` and returns the first cause that is a `T`.
///
/// `err` itself is not inspected, only its (transitive) sources. Returns `None` when no
/// cause in the chain downcasts to `T`.
#[cfg(not(target_arch = "wasm32"))]
fn get_source_error_type<T: std::error::Error + 'static>(
    err: &dyn std::error::Error,
) -> Option<&T> {
    std::iter::successors(err.source(), |&cause| cause.source())
        .find_map(|cause| cause.downcast_ref::<T>())
}