aws_smithy_runtime/client/retries/
classifiers.rs1use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
7use aws_smithy_runtime_api::client::retries::classifiers::{
8 ClassifyRetry, RetryAction, RetryClassifierPriority, SharedRetryClassifier,
9};
10use aws_smithy_types::retry::ProvideErrorKind;
11use std::borrow::Cow;
12use std::error::Error as StdError;
13use std::marker::PhantomData;
14
/// A retry classifier for operation errors that are modeled as retryable.
///
/// When an attempt fails with an operation error, the error is downcast to `E`. If that
/// succeeds, [`ProvideErrorKind::retryable_error_kind`] decides whether a retry is indicated.
#[derive(Debug)]
pub struct ModeledAsRetryableClassifier<E> {
    _inner: PhantomData<E>,
}

// Written by hand: `#[derive(Default)]` would require `E: Default`, which modeled error
// types generally do not implement, even though no `E` value is ever stored.
impl<E> Default for ModeledAsRetryableClassifier<E> {
    fn default() -> Self {
        Self {
            _inner: PhantomData,
        }
    }
}
20
21impl<E> ModeledAsRetryableClassifier<E> {
22 pub fn new() -> Self {
24 Self {
25 _inner: PhantomData,
26 }
27 }
28
29 pub fn priority() -> RetryClassifierPriority {
31 RetryClassifierPriority::modeled_as_retryable_classifier()
32 }
33}
34
35impl<E> ClassifyRetry for ModeledAsRetryableClassifier<E>
36where
37 E: StdError + ProvideErrorKind + Send + Sync + 'static,
38{
39 fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
40 let output_or_error = ctx.output_or_error();
42 let error = match output_or_error {
44 Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
45 Some(Err(err)) => err,
46 };
47 error
49 .as_operation_error()
50 .and_then(|err| err.downcast_ref::<E>())
52 .and_then(|err| err.retryable_error_kind().map(RetryAction::retryable_error))
54 .unwrap_or_default()
55 }
56
57 fn name(&self) -> &'static str {
58 "Errors Modeled As Retryable"
59 }
60
61 fn priority(&self) -> RetryClassifierPriority {
62 Self::priority()
63 }
64}
65
/// A retry classifier for transient failures.
///
/// Response errors, timeout errors, and connector errors that are timeouts or I/O failures
/// are classified as transient. Other connector errors are classified by the error kind
/// they expose, if any. The type parameter `E` is not consulted during classification.
#[derive(Debug)]
pub struct TransientErrorClassifier<E> {
    _inner: PhantomData<E>,
}

// Written by hand: `#[derive(Default)]` would require `E: Default`, which error types
// generally do not implement, even though no `E` value is ever stored.
impl<E> Default for TransientErrorClassifier<E> {
    fn default() -> Self {
        Self {
            _inner: PhantomData,
        }
    }
}
71
72impl<E> TransientErrorClassifier<E> {
73 pub fn new() -> Self {
75 Self {
76 _inner: PhantomData,
77 }
78 }
79
80 pub fn priority() -> RetryClassifierPriority {
82 RetryClassifierPriority::transient_error_classifier()
83 }
84}
85
86impl<E> ClassifyRetry for TransientErrorClassifier<E>
87where
88 E: StdError + Send + Sync + 'static,
89{
90 fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
91 let output_or_error = ctx.output_or_error();
93 let error = match output_or_error {
95 Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
96 Some(Err(err)) => err,
97 };
98
99 if error.is_response_error() || error.is_timeout_error() {
100 RetryAction::transient_error()
101 } else if let Some(error) = error.as_connector_error() {
102 if error.is_timeout() || error.is_io() {
103 RetryAction::transient_error()
104 } else {
105 error
106 .as_other()
107 .map(RetryAction::retryable_error)
108 .unwrap_or_default()
109 }
110 } else {
111 RetryAction::NoActionIndicated
112 }
113 }
114
115 fn name(&self) -> &'static str {
116 "Retryable Smithy Errors"
117 }
118
119 fn priority(&self) -> RetryClassifierPriority {
120 Self::priority()
121 }
122}
123
/// HTTP status codes that `HttpStatusCodeClassifier::default()` treats as transient errors.
const TRANSIENT_ERROR_STATUS_CODES: &[u16] = &[
    500, // Internal Server Error
    502, // Bad Gateway
    503, // Service Unavailable
    504, // Gateway Timeout
];
125
/// A retry classifier that marks responses as transient errors when their HTTP status code
/// is one of a configured set of codes.
#[derive(Debug)]
pub struct HttpStatusCodeClassifier {
    // Either a borrowed `'static` list or an owned, caller-supplied one.
    retryable_status_codes: Cow<'static, [u16]>,
}
132
133impl Default for HttpStatusCodeClassifier {
134 fn default() -> Self {
135 Self::new_from_codes(TRANSIENT_ERROR_STATUS_CODES.to_owned())
136 }
137}
138
139impl HttpStatusCodeClassifier {
140 pub fn new_from_codes(retryable_status_codes: impl Into<Cow<'static, [u16]>>) -> Self {
144 Self {
145 retryable_status_codes: retryable_status_codes.into(),
146 }
147 }
148
149 pub fn priority() -> RetryClassifierPriority {
151 RetryClassifierPriority::http_status_code_classifier()
152 }
153}
154
155impl ClassifyRetry for HttpStatusCodeClassifier {
156 fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
157 let is_retryable = ctx
158 .response()
159 .map(|res| res.status().into())
160 .map(|status| self.retryable_status_codes.contains(&status))
161 .unwrap_or_default();
162
163 if is_retryable {
164 RetryAction::transient_error()
165 } else {
166 RetryAction::NoActionIndicated
167 }
168 }
169
170 fn name(&self) -> &'static str {
171 "HTTP Status Code"
172 }
173
174 fn priority(&self) -> RetryClassifierPriority {
175 Self::priority()
176 }
177}
178
179pub fn run_classifiers_on_ctx(
183 classifiers: impl Iterator<Item = SharedRetryClassifier>,
184 ctx: &InterceptorContext,
185) -> RetryAction {
186 let mut result = RetryAction::NoActionIndicated;
188
189 for classifier in classifiers {
190 let new_result = classifier.classify_retry(ctx);
191
192 if new_result == RetryAction::NoActionIndicated {
195 continue;
196 }
197
198 tracing::trace!(
200 "Classifier '{}' set the result of classification to '{}'",
201 classifier.name(),
202 new_result
203 );
204 result = new_result;
205
206 if result == RetryAction::RetryForbidden {
208 tracing::trace!("retry classification ending early because a `RetryAction::RetryForbidden` was emitted",);
209 break;
210 }
211 }
212
213 result
214}
215
#[cfg(test)]
mod test {
    use super::{
        run_classifiers_on_ctx, HttpStatusCodeClassifier, ModeledAsRetryableClassifier,
        TransientErrorClassifier,
    };
    use aws_smithy_runtime_api::client::interceptors::context::{
        Error, Input, InterceptorContext, Output,
    };
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{
        ClassifyRetry, RetryAction, SharedRetryClassifier,
    };
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
    use std::fmt;

    /// An error type that does not model any retry behavior.
    #[derive(Debug, PartialEq, Eq, Clone)]
    struct UnmodeledError;

    impl fmt::Display for UnmodeledError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "UnmodeledError")
        }
    }

    impl std::error::Error for UnmodeledError {}

    /// An error type modeled as retryable with `ErrorKind::ClientError`.
    #[derive(Debug)]
    struct RetryableError;

    impl fmt::Display for RetryableError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "Some retryable error")
        }
    }

    impl ProvideErrorKind for RetryableError {
        fn retryable_error_kind(&self) -> Option<ErrorKind> {
            Some(ErrorKind::ClientError)
        }

        fn code(&self) -> Option<&str> {
            None
        }
    }

    impl std::error::Error for RetryableError {}

    /// Builds a context whose response carries the given HTTP status code.
    fn ctx_with_status(status: u16) -> InterceptorContext {
        let res = http_02x::Response::builder()
            .status(status)
            .body("error!")
            .unwrap()
            .map(SdkBody::from);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_response(res.try_into().unwrap());
        ctx
    }

    /// Builds a context whose attempt failed with the given orchestrator error.
    fn ctx_with_error(error: OrchestratorError<Error>) -> InterceptorContext {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(error));
        ctx
    }

    /// Builds a context whose attempt succeeded.
    fn ctx_with_output() -> InterceptorContext {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
        ctx
    }

    #[test]
    fn classify_by_response_status() {
        let policy = HttpStatusCodeClassifier::default();
        let ctx = ctx_with_status(500);
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn classify_by_response_status_not_retryable() {
        let policy = HttpStatusCodeClassifier::default();
        let ctx = ctx_with_status(408);
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn classify_by_custom_response_status() {
        let policy = HttpStatusCodeClassifier::new_from_codes(vec![408_u16]);
        let ctx = ctx_with_status(408);
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn classify_without_response() {
        let policy = HttpStatusCodeClassifier::default();
        let ctx = InterceptorContext::new(Input::doesnt_matter());
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn classify_by_error_kind() {
        let policy = ModeledAsRetryableClassifier::<RetryableError>::new();
        let ctx = ctx_with_error(OrchestratorError::operation(Error::erase(RetryableError)));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::client_error());
    }

    #[test]
    fn modeled_classifier_ignores_other_error_types() {
        let policy = ModeledAsRetryableClassifier::<RetryableError>::new();
        let ctx = ctx_with_error(OrchestratorError::operation(Error::erase(UnmodeledError)));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn modeled_classifier_ignores_successful_output() {
        let policy = ModeledAsRetryableClassifier::<RetryableError>::new();
        let ctx = ctx_with_output();
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn classify_response_error() {
        let policy = TransientErrorClassifier::<UnmodeledError>::new();
        let ctx = ctx_with_error(OrchestratorError::response(
            "I am a response error".into(),
        ));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn classify_timeout_error() {
        let policy = TransientErrorClassifier::<UnmodeledError>::new();
        let ctx = ctx_with_error(OrchestratorError::timeout(
            "I am a timeout error".into(),
        ));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn transient_classifier_ignores_successful_output() {
        let policy = TransientErrorClassifier::<UnmodeledError>::new();
        let ctx = ctx_with_output();
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn run_classifiers_without_classifiers() {
        let ctx = ctx_with_status(500);
        assert_eq!(
            run_classifiers_on_ctx(std::iter::empty::<SharedRetryClassifier>(), &ctx),
            RetryAction::NoActionIndicated
        );
    }

    #[test]
    fn run_classifiers_keeps_indicated_action() {
        // The modeled classifier sees no error and indicates nothing, so it must not
        // overwrite the status-code classifier's result.
        let classifiers = vec![
            SharedRetryClassifier::new(HttpStatusCodeClassifier::default()),
            SharedRetryClassifier::new(ModeledAsRetryableClassifier::<RetryableError>::new()),
        ];
        let ctx = ctx_with_status(500);
        assert_eq!(
            run_classifiers_on_ctx(classifiers.into_iter(), &ctx),
            RetryAction::transient_error()
        );
    }
}