1use crate::client::waiters::backoff::{Backoff, RandomImpl};
7use aws_smithy_async::{
8 rt::sleep::{AsyncSleep, SharedAsyncSleep},
9 time::SharedTimeSource,
10};
11use aws_smithy_runtime_api::client::waiters::FinalPoll;
12use aws_smithy_runtime_api::client::{orchestrator::HttpResponse, result::SdkError};
13use aws_smithy_runtime_api::client::{
14 result::CreateUnhandledError,
15 waiters::error::{ExceededMaxWait, FailureState, OperationFailed, WaiterError},
16};
17use std::future::Future;
18use std::time::Duration;
19
20mod backoff;
21
/// Outcome of evaluating a waiter's acceptors against the result of a single poll.
///
/// [`WaiterOrchestrator::orchestrate`] inspects this value after every poll to decide
/// whether to stop (successfully or with an error) or to poll again.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AcceptorState {
    /// None of the acceptors matched the poll result.
    ///
    /// The orchestrator fails the waiter when the poll itself returned an error, and
    /// otherwise handles this the same way as [`AcceptorState::Retry`].
    NoAcceptorsMatched,
    /// The desired state was reached; polling stops and the final poll is returned.
    Success,
    /// A terminal failure state was reached; polling stops with an error.
    Failure,
    /// The desired state has not been reached yet; the orchestrator should poll again.
    Retry,
}
43
44pub struct WaiterOrchestrator<AcceptorFn, OperationFn> {
49 backoff: Backoff,
50 time_source: SharedTimeSource,
51 sleep_impl: SharedAsyncSleep,
52 acceptor_fn: AcceptorFn,
53 operation_fn: OperationFn,
54}
55
56impl WaiterOrchestrator<(), ()> {
57 pub fn builder() -> WaiterOrchestratorBuilder<(), ()> {
59 WaiterOrchestratorBuilder::default()
60 }
61}
62
63impl<AcceptorFn, OperationFn> WaiterOrchestrator<AcceptorFn, OperationFn> {
64 fn new(
65 backoff: Backoff,
66 time_source: SharedTimeSource,
67 sleep_impl: SharedAsyncSleep,
68 acceptor_fn: AcceptorFn,
69 operation_fn: OperationFn,
70 ) -> Self {
71 WaiterOrchestrator {
72 backoff,
73 time_source,
74 sleep_impl,
75 acceptor_fn,
76 operation_fn,
77 }
78 }
79}
80
81impl<AcceptorFn, OperationFn, O, E, Fut> WaiterOrchestrator<AcceptorFn, OperationFn>
82where
83 AcceptorFn: Fn(Result<&O, &E>) -> AcceptorState,
84 OperationFn: Fn() -> Fut,
85 Fut: Future<Output = Result<O, SdkError<E, HttpResponse>>>,
86 E: CreateUnhandledError + std::error::Error + Send + Sync + 'static,
87{
88 pub async fn orchestrate(
90 self,
91 ) -> Result<FinalPoll<O, SdkError<E, HttpResponse>>, WaiterError<O, E>> {
92 let start_time = self.time_source.now();
93 let mut attempt = 0;
94 let mut done_retrying = false;
95 loop {
96 tracing::debug!("executing waiter poll attempt #{}", attempt + 1);
97 let result = (self.operation_fn)().await;
98 let error = result.is_err();
99
100 let acceptable_result = result.as_ref().map_err(|err| err.as_service_error());
102 let acceptor_state = match acceptable_result {
103 Ok(output) => (self.acceptor_fn)(Ok(output)),
104 Err(Some(err)) => (self.acceptor_fn)(Err(err)),
105 _ => {
106 return Err(WaiterError::OperationFailed(OperationFailed::new(
108 result.err().expect("can only be an err in this branch"),
109 )));
110 }
111 };
112
113 tracing::debug!("waiter acceptor state: {acceptor_state:?}");
114 match acceptor_state {
115 AcceptorState::Success => return Ok(FinalPoll::new(result)),
116 AcceptorState::Failure => {
117 return Err(WaiterError::FailureState(FailureState::new(
118 FinalPoll::new(result.map_err(|err| err.into_service_error())),
119 )))
120 }
121 AcceptorState::NoAcceptorsMatched if error => {
123 return Err(WaiterError::OperationFailed(OperationFailed::new(
124 result.err().expect("checked above"),
125 )))
126 }
127 AcceptorState::Retry | AcceptorState::NoAcceptorsMatched => {
128 attempt += 1;
129
130 let now = self.time_source.now();
131 let elapsed = now.duration_since(start_time).unwrap_or_default();
132 if !done_retrying && elapsed <= self.backoff.max_wait() {
133 let delay = self.backoff.delay(attempt, elapsed);
134
135 if delay.is_zero() {
140 tracing::debug!(
141 "delay calculated for attempt #{attempt}; elapsed ({elapsed:?}); waiter is close to max time; will immediately poll one last time"
142 );
143 done_retrying = true;
144 } else {
145 tracing::debug!(
146 "delay calculated for attempt #{attempt}; elapsed ({elapsed:?}); waiter will poll again in {delay:?}"
147 );
148 self.sleep_impl.sleep(delay).await;
149 }
150 } else {
151 tracing::debug!(
152 "waiter exceeded max wait time of {:?}",
153 self.backoff.max_wait()
154 );
155 return Err(WaiterError::ExceededMaxWait(ExceededMaxWait::new(
156 self.backoff.max_wait(),
157 elapsed,
158 attempt,
159 )));
160 }
161 }
162 }
163 }
164 }
165}
166
167#[derive(Default)]
169pub struct WaiterOrchestratorBuilder<AcceptorFn = (), OperationFn = ()> {
170 min_delay: Option<Duration>,
171 max_delay: Option<Duration>,
172 max_wait: Option<Duration>,
173 time_source: Option<SharedTimeSource>,
174 sleep_impl: Option<SharedAsyncSleep>,
175 random_fn: RandomImpl,
176 acceptor_fn: Option<AcceptorFn>,
177 operation_fn: Option<OperationFn>,
178}
179
180impl<AcceptorFn, OperationFn> WaiterOrchestratorBuilder<AcceptorFn, OperationFn> {
181 pub fn min_delay(mut self, min_delay: Duration) -> Self {
183 self.min_delay = Some(min_delay);
184 self
185 }
186
187 pub fn max_delay(mut self, max_delay: Duration) -> Self {
189 self.max_delay = Some(max_delay);
190 self
191 }
192
193 pub fn max_wait(mut self, max_wait: Duration) -> Self {
195 self.max_wait = Some(max_wait);
196 self
197 }
198
199 #[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
200 fn random(mut self, random_fn: impl Fn(u64, u64) -> u64 + Send + Sync + 'static) -> Self {
201 self.random_fn = RandomImpl::Override(Box::new(random_fn));
202 self
203 }
204
205 pub fn time_source(mut self, time_source: SharedTimeSource) -> Self {
207 self.time_source = Some(time_source);
208 self
209 }
210
211 pub fn sleep_impl(mut self, sleep_impl: SharedAsyncSleep) -> Self {
213 self.sleep_impl = Some(sleep_impl);
214 self
215 }
216
217 pub fn build(self) -> WaiterOrchestrator<AcceptorFn, OperationFn> {
219 WaiterOrchestrator::new(
220 Backoff::new(
221 self.min_delay.expect("min delay is required"),
222 self.max_delay.expect("max delay is required"),
223 self.max_wait.expect("max wait is required"),
224 self.random_fn,
225 ),
226 self.time_source.expect("time source required"),
227 self.sleep_impl.expect("sleep impl required"),
228 self.acceptor_fn.expect("acceptor fn required"),
229 self.operation_fn.expect("operation fn required"),
230 )
231 }
232}
233
234impl<OperationFn> WaiterOrchestratorBuilder<(), OperationFn> {
235 pub fn acceptor<AcceptorFn>(
237 self,
238 acceptor: AcceptorFn,
239 ) -> WaiterOrchestratorBuilder<AcceptorFn, OperationFn> {
240 WaiterOrchestratorBuilder {
241 min_delay: self.min_delay,
242 max_delay: self.max_delay,
243 max_wait: self.max_wait,
244 time_source: self.time_source,
245 sleep_impl: self.sleep_impl,
246 random_fn: self.random_fn,
247 acceptor_fn: Some(acceptor),
248 operation_fn: self.operation_fn,
249 }
250 }
251}
252
253impl<AcceptorFn> WaiterOrchestratorBuilder<AcceptorFn, ()> {
254 pub fn operation<OperationFn>(
256 self,
257 operation: OperationFn,
258 ) -> WaiterOrchestratorBuilder<AcceptorFn, OperationFn> {
259 WaiterOrchestratorBuilder {
260 min_delay: self.min_delay,
261 max_delay: self.max_delay,
262 max_wait: self.max_wait,
263 time_source: self.time_source,
264 sleep_impl: self.sleep_impl,
265 random_fn: self.random_fn,
266 acceptor_fn: self.acceptor_fn,
267 operation_fn: Some(operation),
268 }
269 }
270}
271
272pub fn attach_waiter_tracing_span<O, E>(
275 future: impl Future<Output = Result<FinalPoll<O, SdkError<E, HttpResponse>>, WaiterError<O, E>>>,
276) -> impl Future<Output = Result<FinalPoll<O, SdkError<E, HttpResponse>>, WaiterError<O, E>>> {
277 use tracing::Instrument;
278
279 let span = tracing::debug_span!("waiter", waiter_id = fastrand::u32(1_000_000..10_000_000));
281 future.instrument(span)
282}
283
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod tests {
    //! Unit tests for [`WaiterOrchestrator`], driven by a manually advanced time source.

    use super::*;
    use crate::test_util::capture_test_logs::show_test_logs;
    use aws_smithy_async::{
        test_util::tick_advance_sleep::tick_advance_time_and_sleep, time::TimeSource,
    };
    use aws_smithy_runtime_api::{http::StatusCode, shared::IntoShared};
    use aws_smithy_types::body::SdkBody;
    use std::{
        fmt,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc, Mutex,
        },
        time::SystemTime,
    };

    /// Modeled (service) error used by the tests.
    #[derive(Debug)]
    struct TestError;
    impl std::error::Error for TestError {}
    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("TestError")
        }
    }
    impl CreateUnhandledError for TestError {
        fn create_unhandled_error(
            _source: Box<dyn std::error::Error + Send + Sync + 'static>,
            _meta: Option<aws_smithy_types::error::ErrorMetadata>,
        ) -> Self {
            unreachable!("If this is called, there is a bug in the orchestrator implementation. Unmodeled errors should never make it into FailureState.")
        }
    }

    /// Returns a builder preconfigured with 2s min delay, 120s max delay, 300s max wait,
    /// and a deterministic "random" function that always picks the midpoint.
    fn test_orchestrator(
        sleep_impl: impl IntoShared<SharedAsyncSleep>,
        time_source: impl IntoShared<SharedTimeSource>,
    ) -> WaiterOrchestratorBuilder<(), ()> {
        let test_random = |min: u64, max: u64| (min + max) / 2;
        WaiterOrchestrator::builder()
            .min_delay(Duration::from_secs(2))
            .max_delay(Duration::from_secs(120))
            .max_wait(Duration::from_secs(300))
            .random(test_random)
            .sleep_impl(sleep_impl.into_shared())
            .time_source(time_source.into_shared())
    }

    #[tokio::test]
    async fn immediate_success() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Success)
            .operation(|| async { Result::<_, SdkError<TestError, HttpResponse>>::Ok(5usize) })
            .build();

        let result = orchestrator.orchestrate().await;
        assert!(result.is_ok());
        assert_eq!(5, *result.unwrap().as_result().unwrap());
    }

    #[tokio::test]
    async fn immediate_failure() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Failure)
            .operation(|| async { Result::<_, SdkError<TestError, HttpResponse>>::Ok(5usize) })
            .build();

        let result = orchestrator.orchestrate().await;
        assert!(
            matches!(result, Err(WaiterError::FailureState(_))),
            "expected failure state, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn five_polls_then_success() {
        let _logs = show_test_logs();

        let (time_source, sleep_impl) = tick_advance_time_and_sleep();

        // Succeed only once the operation returns 5; retry on every other output.
        let acceptor = |result: Result<&usize, &TestError>| match result {
            Err(_) => unreachable!(),
            Ok(5) => AcceptorState::Success,
            _ => AcceptorState::Retry,
        };

        // Records the time (whole seconds since the epoch) at which each poll completed.
        let times = Arc::new(Mutex::new(Vec::new()));
        let attempt = Arc::new(AtomicUsize::new(1));
        let operation = {
            let sleep_impl = sleep_impl.clone();
            let time_source = time_source.clone();
            let times = times.clone();
            move || {
                let attempt = attempt.clone();
                let sleep_impl = sleep_impl.clone();
                let time_source = time_source.clone();
                let times = times.clone();
                async move {
                    // Simulate the operation itself taking one second to run.
                    sleep_impl.sleep(Duration::from_secs(1)).await;
                    times.lock().unwrap().push(
                        time_source
                            .now()
                            .duration_since(SystemTime::UNIX_EPOCH)
                            .unwrap()
                            .as_secs(),
                    );
                    // Each poll returns the current attempt number, then increments it.
                    Result::<_, SdkError<TestError, HttpResponse>>::Ok(
                        attempt.fetch_add(1, Ordering::SeqCst),
                    )
                }
            }
        };

        let orchestrator = test_orchestrator(sleep_impl.clone(), time_source.clone())
            .acceptor(acceptor)
            .operation(operation)
            .build();

        let task = tokio::spawn(orchestrator.orchestrate());
        tokio::task::yield_now().await;
        time_source.tick(Duration::from_secs(500)).await;
        let result = task.await.unwrap();

        assert!(result.is_ok());
        assert_eq!(5, *result.unwrap().as_result().unwrap());
        assert_eq!(vec![1, 4, 8, 14, 24], *times.lock().unwrap());
    }

    #[tokio::test]
    async fn exceed_max_wait_time() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();

        // The acceptor never reaches a terminal state, so the waiter must time out.
        let orchestrator = test_orchestrator(sleep_impl.clone(), time_source.clone())
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Retry)
            .operation(|| async { Result::<_, SdkError<TestError, HttpResponse>>::Ok(1) })
            .build();

        let task = tokio::spawn(orchestrator.orchestrate());
        tokio::task::yield_now().await;
        time_source.tick(Duration::from_secs(500)).await;
        let result = task.await.unwrap();

        match result {
            Err(WaiterError::ExceededMaxWait(context)) => {
                assert_eq!(Duration::from_secs(300), context.max_wait());
                assert_eq!(300, context.elapsed().as_secs());
                assert_eq!(12, context.poll_count());
            }
            _ => panic!("expected ExceededMaxWait, got {result:?}"),
        }
    }

    #[tokio::test]
    async fn operation_timed_out() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        // A timeout is not a service error, so the acceptor must never be consulted.
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| unreachable!())
            .operation(|| async {
                Result::<usize, SdkError<TestError, HttpResponse>>::Err(SdkError::timeout_error(
                    "test",
                ))
            })
            .build();

        match orchestrator.orchestrate().await {
            Err(WaiterError::OperationFailed(err)) => match err.error() {
                SdkError::TimeoutError(_) => { /* expected */ }
                result => panic!("unexpected final poll: {result:?}"),
            },
            result => panic!("unexpected result: {result:?}"),
        }
    }

    #[tokio::test]
    async fn modeled_service_error_no_acceptors_matched() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::NoAcceptorsMatched)
            .operation(|| async {
                Result::<usize, SdkError<TestError, HttpResponse>>::Err(SdkError::service_error(
                    TestError,
                    HttpResponse::new(StatusCode::try_from(400).unwrap(), SdkBody::empty()),
                ))
            })
            .build();

        // An unmatched service error must surface as `OperationFailed`, not as a retry.
        match orchestrator.orchestrate().await {
            Err(WaiterError::OperationFailed(err)) => match err.error() {
                SdkError::ServiceError(_) => { /* expected */ }
                result => panic!("unexpected result: {result:?}"),
            },
            result => panic!("unexpected result: {result:?}"),
        }
    }

    #[tokio::test]
    async fn modeled_error_matched_as_failure() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Failure)
            .operation(|| async {
                Result::<usize, SdkError<TestError, HttpResponse>>::Err(SdkError::service_error(
                    TestError,
                    HttpResponse::new(StatusCode::try_from(400).unwrap(), SdkBody::empty()),
                ))
            })
            .build();

        match orchestrator.orchestrate().await {
            Err(WaiterError::FailureState(err)) => match err.final_poll().as_result() {
                Err(TestError) => { /* expected */ }
                result => panic!("unexpected final poll: {result:?}"),
            },
            result => panic!("unexpected result: {result:?}"),
        }
    }

    #[tokio::test]
    async fn modeled_error_matched_as_success() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Success)
            .operation(|| async {
                Result::<usize, SdkError<TestError, HttpResponse>>::Err(SdkError::service_error(
                    TestError,
                    HttpResponse::new(StatusCode::try_from(400).unwrap(), SdkBody::empty()),
                ))
            })
            .build();

        // A service error can still be the waiter's success condition.
        let result = orchestrator.orchestrate().await;
        assert!(result.is_ok());
        assert!(result.unwrap().as_result().is_err());
    }
}