aws_smithy_runtime/client/
waiters.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::client::waiters::backoff::{Backoff, RandomImpl};
7use aws_smithy_async::{
8    rt::sleep::{AsyncSleep, SharedAsyncSleep},
9    time::SharedTimeSource,
10};
11use aws_smithy_runtime_api::client::waiters::FinalPoll;
12use aws_smithy_runtime_api::client::{orchestrator::HttpResponse, result::SdkError};
13use aws_smithy_runtime_api::client::{
14    result::CreateUnhandledError,
15    waiters::error::{ExceededMaxWait, FailureState, OperationFailed, WaiterError},
16};
17use std::future::Future;
18use std::time::Duration;
19
20mod backoff;
21
/// Outcome of running a waiter's acceptors against a single poll result.
///
/// This enum (loosely) mirrors the [acceptor state] from the Smithy waiters spec.
/// On top of the spec's states it carries a `NoAcceptorsMatched` variant for the
/// case where none of the modeled acceptors matched the response. The spec mentions
/// that case but doesn't make it an official member of the acceptor state enum.
/// An `Option<AcceptorState>` could have expressed the same thing, but a dedicated
/// variant seemed cleaner.
///
/// [acceptor state]: https://smithy.io/2.0/additional-specs/waiters.html#acceptorstate-enum
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AcceptorState {
    /// No modeled acceptor matched the response.
    NoAcceptorsMatched,
    /// The response was matched by a `success` acceptor.
    Success,
    /// The response was matched by a `failure` acceptor.
    Failure,
    /// The response was matched by a `retry` acceptor.
    Retry,
}
43
44/// Orchestrates waiting via polling with jittered exponential backoff.
45///
46/// This is meant to be used internally by the generated code to provide
47/// waiter functionality.
48pub struct WaiterOrchestrator<AcceptorFn, OperationFn> {
49    backoff: Backoff,
50    time_source: SharedTimeSource,
51    sleep_impl: SharedAsyncSleep,
52    acceptor_fn: AcceptorFn,
53    operation_fn: OperationFn,
54}
55
56impl WaiterOrchestrator<(), ()> {
57    /// Returns a builder for the waiter orchestrator.
58    pub fn builder() -> WaiterOrchestratorBuilder<(), ()> {
59        WaiterOrchestratorBuilder::default()
60    }
61}
62
63impl<AcceptorFn, OperationFn> WaiterOrchestrator<AcceptorFn, OperationFn> {
64    fn new(
65        backoff: Backoff,
66        time_source: SharedTimeSource,
67        sleep_impl: SharedAsyncSleep,
68        acceptor_fn: AcceptorFn,
69        operation_fn: OperationFn,
70    ) -> Self {
71        WaiterOrchestrator {
72            backoff,
73            time_source,
74            sleep_impl,
75            acceptor_fn,
76            operation_fn,
77        }
78    }
79}
80
81impl<AcceptorFn, OperationFn, O, E, Fut> WaiterOrchestrator<AcceptorFn, OperationFn>
82where
83    AcceptorFn: Fn(Result<&O, &E>) -> AcceptorState,
84    OperationFn: Fn() -> Fut,
85    Fut: Future<Output = Result<O, SdkError<E, HttpResponse>>>,
86    E: CreateUnhandledError + std::error::Error + Send + Sync + 'static,
87{
88    /// Orchestrates waiting via polling with jittered exponential backoff.
89    pub async fn orchestrate(
90        self,
91    ) -> Result<FinalPoll<O, SdkError<E, HttpResponse>>, WaiterError<O, E>> {
92        let start_time = self.time_source.now();
93        let mut attempt = 0;
94        let mut done_retrying = false;
95        loop {
96            tracing::debug!("executing waiter poll attempt #{}", attempt + 1);
97            let result = (self.operation_fn)().await;
98            let error = result.is_err();
99
100            // "acceptable result" in this context means "an acceptor's matcher can match this result type"
101            let acceptable_result = result.as_ref().map_err(|err| err.as_service_error());
102            let acceptor_state = match acceptable_result {
103                Ok(output) => (self.acceptor_fn)(Ok(output)),
104                Err(Some(err)) => (self.acceptor_fn)(Err(err)),
105                _ => {
106                    // If we got an unmatchable failure (basically anything unmodeled), then just immediately exit
107                    return Err(WaiterError::OperationFailed(OperationFailed::new(
108                        result.err().expect("can only be an err in this branch"),
109                    )));
110                }
111            };
112
113            tracing::debug!("waiter acceptor state: {acceptor_state:?}");
114            match acceptor_state {
115                AcceptorState::Success => return Ok(FinalPoll::new(result)),
116                AcceptorState::Failure => {
117                    return Err(WaiterError::FailureState(FailureState::new(
118                        FinalPoll::new(result.map_err(|err| err.into_service_error())),
119                    )))
120                }
121                // This occurs when there was a modeled error response, but none of the acceptors matched it
122                AcceptorState::NoAcceptorsMatched if error => {
123                    return Err(WaiterError::OperationFailed(OperationFailed::new(
124                        result.err().expect("checked above"),
125                    )))
126                }
127                AcceptorState::Retry | AcceptorState::NoAcceptorsMatched => {
128                    attempt += 1;
129
130                    let now = self.time_source.now();
131                    let elapsed = now.duration_since(start_time).unwrap_or_default();
132                    if !done_retrying && elapsed <= self.backoff.max_wait() {
133                        let delay = self.backoff.delay(attempt, elapsed);
134
135                        // The backoff function returns a zero delay when it is min_delay time away
136                        // from max_time. If we didn't detect this and stop polling, then we could
137                        // slam the server at the very end of the wait period for servers that are
138                        // really fast (for example, a few milliseconds total round-trip latency).
139                        if delay.is_zero() {
140                            tracing::debug!(
141                                "delay calculated for attempt #{attempt}; elapsed ({elapsed:?}); waiter is close to max time; will immediately poll one last time"
142                            );
143                            done_retrying = true;
144                        } else {
145                            tracing::debug!(
146                                "delay calculated for attempt #{attempt}; elapsed ({elapsed:?}); waiter will poll again in {delay:?}"
147                            );
148                            self.sleep_impl.sleep(delay).await;
149                        }
150                    } else {
151                        tracing::debug!(
152                            "waiter exceeded max wait time of {:?}",
153                            self.backoff.max_wait()
154                        );
155                        return Err(WaiterError::ExceededMaxWait(ExceededMaxWait::new(
156                            self.backoff.max_wait(),
157                            elapsed,
158                            attempt,
159                        )));
160                    }
161                }
162            }
163        }
164    }
165}
166
167/// Builder for [`WaiterOrchestrator`].
168#[derive(Default)]
169pub struct WaiterOrchestratorBuilder<AcceptorFn = (), OperationFn = ()> {
170    min_delay: Option<Duration>,
171    max_delay: Option<Duration>,
172    max_wait: Option<Duration>,
173    time_source: Option<SharedTimeSource>,
174    sleep_impl: Option<SharedAsyncSleep>,
175    random_fn: RandomImpl,
176    acceptor_fn: Option<AcceptorFn>,
177    operation_fn: Option<OperationFn>,
178}
179
180impl<AcceptorFn, OperationFn> WaiterOrchestratorBuilder<AcceptorFn, OperationFn> {
181    /// Set the minimum delay time for the waiter.
182    pub fn min_delay(mut self, min_delay: Duration) -> Self {
183        self.min_delay = Some(min_delay);
184        self
185    }
186
187    /// Set the maximum delay time for the waiter.
188    pub fn max_delay(mut self, max_delay: Duration) -> Self {
189        self.max_delay = Some(max_delay);
190        self
191    }
192
193    /// Set the maximum total wait time for the waiter.
194    pub fn max_wait(mut self, max_wait: Duration) -> Self {
195        self.max_wait = Some(max_wait);
196        self
197    }
198
199    #[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
200    fn random(mut self, random_fn: impl Fn(u64, u64) -> u64 + Send + Sync + 'static) -> Self {
201        self.random_fn = RandomImpl::Override(Box::new(random_fn));
202        self
203    }
204
205    /// Set the time source the waiter will use.
206    pub fn time_source(mut self, time_source: SharedTimeSource) -> Self {
207        self.time_source = Some(time_source);
208        self
209    }
210
211    /// Set the async sleep implementation the waiter will use to delay.
212    pub fn sleep_impl(mut self, sleep_impl: SharedAsyncSleep) -> Self {
213        self.sleep_impl = Some(sleep_impl);
214        self
215    }
216
217    /// Build a waiter orchestrator.
218    pub fn build(self) -> WaiterOrchestrator<AcceptorFn, OperationFn> {
219        WaiterOrchestrator::new(
220            Backoff::new(
221                self.min_delay.expect("min delay is required"),
222                self.max_delay.expect("max delay is required"),
223                self.max_wait.expect("max wait is required"),
224                self.random_fn,
225            ),
226            self.time_source.expect("time source required"),
227            self.sleep_impl.expect("sleep impl required"),
228            self.acceptor_fn.expect("acceptor fn required"),
229            self.operation_fn.expect("operation fn required"),
230        )
231    }
232}
233
234impl<OperationFn> WaiterOrchestratorBuilder<(), OperationFn> {
235    /// Set the acceptor function for the waiter.
236    pub fn acceptor<AcceptorFn>(
237        self,
238        acceptor: AcceptorFn,
239    ) -> WaiterOrchestratorBuilder<AcceptorFn, OperationFn> {
240        WaiterOrchestratorBuilder {
241            min_delay: self.min_delay,
242            max_delay: self.max_delay,
243            max_wait: self.max_wait,
244            time_source: self.time_source,
245            sleep_impl: self.sleep_impl,
246            random_fn: self.random_fn,
247            acceptor_fn: Some(acceptor),
248            operation_fn: self.operation_fn,
249        }
250    }
251}
252
253impl<AcceptorFn> WaiterOrchestratorBuilder<AcceptorFn, ()> {
254    /// Set the operation function for the waiter.
255    pub fn operation<OperationFn>(
256        self,
257        operation: OperationFn,
258    ) -> WaiterOrchestratorBuilder<AcceptorFn, OperationFn> {
259        WaiterOrchestratorBuilder {
260            min_delay: self.min_delay,
261            max_delay: self.max_delay,
262            max_wait: self.max_wait,
263            time_source: self.time_source,
264            sleep_impl: self.sleep_impl,
265            random_fn: self.random_fn,
266            acceptor_fn: self.acceptor_fn,
267            operation_fn: Some(operation),
268        }
269    }
270}
271
272/// Attaches a tracing span with a semi-unique waiter ID number so that all the operations
273/// made by the waiter can be correlated together in logs.
274pub fn attach_waiter_tracing_span<O, E>(
275    future: impl Future<Output = Result<FinalPoll<O, SdkError<E, HttpResponse>>, WaiterError<O, E>>>,
276) -> impl Future<Output = Result<FinalPoll<O, SdkError<E, HttpResponse>>, WaiterError<O, E>>> {
277    use tracing::Instrument;
278
279    // Create a random seven-digit ID for the waiter so that it can be correlated in the logs.
280    let span = tracing::debug_span!("waiter", waiter_id = fastrand::u32(1_000_000..10_000_000));
281    future.instrument(span)
282}
283
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod tests {
    //! Tests for [`WaiterOrchestrator`], driven by a manually advanced time source
    //! and sleep implementation.

    use super::*;
    use crate::test_util::capture_test_logs::show_test_logs;
    use aws_smithy_async::{
        test_util::tick_advance_sleep::tick_advance_time_and_sleep, time::TimeSource,
    };
    use aws_smithy_runtime_api::{http::StatusCode, shared::IntoShared};
    use aws_smithy_types::body::SdkBody;
    use std::{
        fmt,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc, Mutex,
        },
        time::SystemTime,
    };

    /// Minimal modeled error type used as the operation's service error.
    #[derive(Debug)]
    struct TestError;
    impl std::error::Error for TestError {}
    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("TestError")
        }
    }
    impl CreateUnhandledError for TestError {
        fn create_unhandled_error(
            _source: Box<dyn std::error::Error + Send + Sync + 'static>,
            _meta: Option<aws_smithy_types::error::ErrorMetadata>,
        ) -> Self {
            unreachable!("If this is called, there is a bug in the orchestrator implementation. Unmodeled errors should never make it into FailureState.")
        }
    }

    /// Returns a builder preconfigured with a 2s min delay, 120s max delay, 300s max wait,
    /// and a deterministic "random" function that always picks the midpoint of its range.
    fn test_orchestrator(
        sleep_impl: impl IntoShared<SharedAsyncSleep>,
        time_source: impl IntoShared<SharedTimeSource>,
    ) -> WaiterOrchestratorBuilder<(), ()> {
        let test_random = |min: u64, max: u64| (min + max) / 2;
        WaiterOrchestrator::builder()
            .min_delay(Duration::from_secs(2))
            .max_delay(Duration::from_secs(120))
            .max_wait(Duration::from_secs(300))
            .random(test_random)
            .sleep_impl(sleep_impl.into_shared())
            .time_source(time_source.into_shared())
    }

    /// Builds a modeled service error: a [`TestError`] paired with an empty HTTP 400 response.
    fn modeled_service_error() -> SdkError<TestError, HttpResponse> {
        SdkError::service_error(
            TestError,
            HttpResponse::new(StatusCode::try_from(400).unwrap(), SdkBody::empty()),
        )
    }

    /// A `Success` acceptor state on the first poll returns that poll's output.
    #[tokio::test]
    async fn immediate_success() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Success)
            .operation(|| async { Result::<_, SdkError<TestError, HttpResponse>>::Ok(5usize) })
            .build();

        let result = orchestrator.orchestrate().await;
        assert!(result.is_ok());
        assert_eq!(5, *result.unwrap().as_result().unwrap());
    }

    /// A `Failure` acceptor state on the first poll yields `WaiterError::FailureState`.
    #[tokio::test]
    async fn immediate_failure() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Failure)
            .operation(|| async { Result::<_, SdkError<TestError, HttpResponse>>::Ok(5usize) })
            .build();

        let result = orchestrator.orchestrate().await;
        assert!(
            matches!(result, Err(WaiterError::FailureState(_))),
            "expected failure state, got: {result:?}"
        );
    }

    /// The waiter keeps polling on `Retry` and succeeds on the fifth poll; the recorded
    /// poll times pin the delays between polls.
    #[tokio::test]
    async fn five_polls_then_success() {
        let _logs = show_test_logs();

        let (time_source, sleep_impl) = tick_advance_time_and_sleep();

        let acceptor = |result: Result<&usize, &TestError>| match result {
            Err(_) => unreachable!(),
            Ok(5) => AcceptorState::Success,
            _ => AcceptorState::Retry,
        };

        // Seconds since the epoch at which each poll completed.
        let times = Arc::new(Mutex::new(Vec::new()));
        // The operation returns its own 1-based invocation count.
        let attempt = Arc::new(AtomicUsize::new(1));
        let operation = {
            let sleep_impl = sleep_impl.clone();
            let time_source = time_source.clone();
            let times = times.clone();
            move || {
                let attempt = attempt.clone();
                let sleep_impl = sleep_impl.clone();
                let time_source = time_source.clone();
                let times = times.clone();
                async move {
                    // simulate time passing for the network hop/service processing time
                    sleep_impl.sleep(Duration::from_secs(1)).await;
                    times.lock().unwrap().push(
                        time_source
                            .now()
                            .duration_since(SystemTime::UNIX_EPOCH)
                            .unwrap()
                            .as_secs(),
                    );
                    Result::<_, SdkError<TestError, HttpResponse>>::Ok(
                        attempt.fetch_add(1, Ordering::SeqCst),
                    )
                }
            }
        };

        let orchestrator = test_orchestrator(sleep_impl.clone(), time_source.clone())
            .acceptor(acceptor)
            .operation(operation)
            .build();

        // Run the waiter in the background and advance the fake clock far enough
        // for it to finish.
        let task = tokio::spawn(orchestrator.orchestrate());
        tokio::task::yield_now().await;
        time_source.tick(Duration::from_secs(500)).await;
        let result = task.await.unwrap();

        assert!(result.is_ok());
        assert_eq!(5, *result.unwrap().as_result().unwrap());
        assert_eq!(vec![1, 4, 8, 14, 24], *times.lock().unwrap());
    }

    /// An acceptor that always says `Retry` eventually produces `ExceededMaxWait`.
    #[tokio::test]
    async fn exceed_max_wait_time() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();

        let orchestrator = test_orchestrator(sleep_impl.clone(), time_source.clone())
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Retry)
            .operation(|| async { Result::<_, SdkError<TestError, HttpResponse>>::Ok(1) })
            .build();

        let task = tokio::spawn(orchestrator.orchestrate());
        tokio::task::yield_now().await;
        time_source.tick(Duration::from_secs(500)).await;
        let result = task.await.unwrap();

        match result {
            Err(WaiterError::ExceededMaxWait(context)) => {
                assert_eq!(Duration::from_secs(300), context.max_wait());
                assert_eq!(300, context.elapsed().as_secs());
                assert_eq!(12, context.poll_count());
            }
            _ => panic!("expected ExceededMaxWait, got {result:?}"),
        }
    }

    /// An unmodeled failure (here a timeout) bypasses the acceptor entirely and is
    /// surfaced as `OperationFailed`.
    #[tokio::test]
    async fn operation_timed_out() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| unreachable!())
            .operation(|| async {
                Result::<usize, SdkError<TestError, HttpResponse>>::Err(SdkError::timeout_error(
                    "test",
                ))
            })
            .build();

        match orchestrator.orchestrate().await {
            Err(WaiterError::OperationFailed(err)) => match err.error() {
                SdkError::TimeoutError(_) => { /* good */ }
                result => panic!("unexpected final poll: {result:?}"),
            },
            result => panic!("unexpected result: {result:?}"),
        }
    }

    /// A modeled service error that no acceptor matches is surfaced as `OperationFailed`
    /// rather than being retried.
    #[tokio::test]
    async fn modeled_service_error_no_acceptors_matched() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::NoAcceptorsMatched)
            .operation(|| async {
                Result::<usize, SdkError<TestError, HttpResponse>>::Err(modeled_service_error())
            })
            .build();

        match orchestrator.orchestrate().await {
            Err(WaiterError::OperationFailed(err)) => match err.error() {
                SdkError::ServiceError(_) => { /* good */ }
                result => panic!("unexpected result: {result:?}"),
            },
            result => panic!("unexpected result: {result:?}"),
        }
    }

    /// A modeled service error matched by a `failure` acceptor ends up in `FailureState`
    /// with the modeled error as the final poll.
    #[tokio::test]
    async fn modeled_error_matched_as_failure() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Failure)
            .operation(|| async {
                Result::<usize, SdkError<TestError, HttpResponse>>::Err(modeled_service_error())
            })
            .build();

        match orchestrator.orchestrate().await {
            Err(WaiterError::FailureState(err)) => match err.final_poll().as_result() {
                Err(TestError) => { /* good */ }
                result => panic!("unexpected final poll: {result:?}"),
            },
            result => panic!("unexpected result: {result:?}"),
        }
    }

    /// A modeled service error matched by a `success` acceptor makes the waiter succeed,
    /// with the error preserved as the final poll.
    #[tokio::test]
    async fn modeled_error_matched_as_success() {
        let _logs = show_test_logs();
        let (time_source, sleep_impl) = tick_advance_time_and_sleep();
        let orchestrator = test_orchestrator(sleep_impl, time_source)
            .acceptor(|_result: Result<&usize, &TestError>| AcceptorState::Success)
            .operation(|| async {
                Result::<usize, SdkError<TestError, HttpResponse>>::Err(modeled_service_error())
            })
            .build();

        let result = orchestrator.orchestrate().await;
        assert!(result.is_ok());
        assert!(result.unwrap().as_result().is_err());
    }
}