futures_retry/
future.rs

1use crate::{ErrorHandler, RetryPolicy};
2use futures::{ready, TryFuture};
3use pin_project_lite::pin_project;
4use std::{
5    future::Future,
6    marker::Unpin,
7    pin::Pin,
8    task::{Context, Poll},
9};
10use tokio::time;
11
12/// A factory trait used to create futures.
13///
14/// We need a factory for the retry logic because when (and if) a future returns an error, its
15/// internal state is undefined and we can't poll on it anymore. Hence we need to create a new one.
16///
17/// By the way, this trait is implemented for any closure that returns a `Future`, so you don't
18/// have to write your own type and implement it to handle some simple cases.
19pub trait FutureFactory {
20    /// An future type that is created by the `new` method.
21    type FutureItem: TryFuture;
22
23    /// Creates a new future. We don't need the factory to be immutable so we pass `self` as a
24    /// mutable reference.
25    fn new(&mut self) -> Self::FutureItem;
26}
27
28impl<T, F> FutureFactory for T
29where
30    T: Unpin + FnMut() -> F,
31    F: TryFuture,
32{
33    type FutureItem = F;
34
35    #[allow(clippy::new_ret_no_self)]
36    fn new(&mut self) -> F {
37        (self)()
38    }
39}
40
41pin_project! {
42    /// A future that transparently launches an underlying future (created by a provided factory each
43    /// time) as many times as needed to get things done.
44    ///
45    /// It is useful fot situations when you need to make several attempts, e.g. for establishing
46    /// connections, RPC calls.
47    ///
48    /// There is also a type to handle `Stream` errors: [`StreamRetry`](struct.StreamRetry.html).#[pin_project]
49    pub struct FutureRetry<F, R>
50    where
51        F: FutureFactory,
52    {
53        factory: F,
54        error_action: R,
55        attempt: usize,
56        #[pin]
57        state: RetryState<F::FutureItem>,
58    }
59}
60
61pin_project! {
62    #[project = RetryStateProj]
63    enum RetryState<F> {
64        NotStarted,
65        WaitingForFuture { #[pin] future: F },
66        TimerActive { #[pin] delay: time::Sleep },
67    }
68}
69
70impl<F: FutureFactory, R> FutureRetry<F, R> {
71    /// Creates a `FutureRetry` using a provided factory and an object of `ErrorHandler` type that
72    /// decides on a retry-policy depending on an encountered error.
73    ///
74    /// Please refer to the `tcp-client` example in the `examples` folder to have a look at a
75    /// possible usage.
76    ///
77    /// # Arguments
78    ///
79    /// * `factory`: a factory that creates futures,
80    /// * `error_action`: a type that handles an error and decides which route to take: simply
81    ///                   try again, wait and then try, or give up (on a critical error for
82    ///                   exapmle).
83    pub fn new(factory: F, error_action: R) -> Self {
84        Self {
85            factory,
86            error_action,
87            state: RetryState::NotStarted,
88            attempt: 1,
89        }
90    }
91}
92
93impl<F: FutureFactory, R> Future for FutureRetry<F, R>
94where
95    R: ErrorHandler<<F::FutureItem as TryFuture>::Error>,
96{
97    type Output =
98        Result<(<<F as FutureFactory>::FutureItem as TryFuture>::Ok, usize), (R::OutError, usize)>;
99
100    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
101        loop {
102            let this = self.as_mut().project();
103            let attempt = *this.attempt;
104            let new_state = match this.state.project() {
105                RetryStateProj::NotStarted => RetryState::WaitingForFuture {
106                    future: this.factory.new(),
107                },
108                RetryStateProj::TimerActive { delay } => {
109                    ready!(delay.poll(cx));
110                    RetryState::WaitingForFuture {
111                        future: this.factory.new(),
112                    }
113                }
114                RetryStateProj::WaitingForFuture { future } => match ready!(future.try_poll(cx)) {
115                    Ok(x) => {
116                        this.error_action.ok(attempt);
117                        *this.attempt = 1;
118                        return Poll::Ready(Ok((x, attempt)));
119                    }
120                    Err(e) => {
121                        *this.attempt += 1;
122                        match this.error_action.handle(attempt, e) {
123                            RetryPolicy::ForwardError(e) => return Poll::Ready(Err((e, attempt))),
124                            RetryPolicy::Repeat => RetryState::WaitingForFuture {
125                                future: this.factory.new(),
126                            },
127                            RetryPolicy::WaitRetry(duration) => RetryState::TimerActive {
128                                delay: time::sleep(duration),
129                            },
130                        }
131                    }
132                },
133            };
134
135            self.as_mut().project().state.set(new_state);
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        future::{err, ok},
        TryFutureExt,
    };
    use std::time::Duration;

    /// Test helper: hands out the futures of an iterator one at a time.
    struct FutureIterator<F>(F);

    impl<I, F> FutureFactory for FutureIterator<I>
    where
        I: Unpin + Iterator<Item = F>,
        F: TryFuture,
    {
        type FutureItem = F;

        /// # Warning
        ///
        /// Panics when the iterator has no more futures to give.
        fn new(&mut self) -> Self::FutureItem {
            let Self(futures) = self;
            futures.next().expect("No more futures!")
        }
    }

    #[tokio::test]
    async fn naive() {
        let retry = FutureRetry::new(|| ok::<_, u8>(1u8), |_| RetryPolicy::Repeat::<u8>);
        let outcome = retry.await;
        assert_eq!(outcome, Ok((1u8, 1)));
    }

    #[tokio::test]
    async fn naive_error_forward() {
        let retry = FutureRetry::new(|| err::<u8, _>(1u8), RetryPolicy::ForwardError);
        let outcome = retry.await;
        assert_eq!(outcome, Err((1u8, 1)));
    }

    #[tokio::test]
    async fn more_complicated_wait() {
        let attempts = FutureIterator(vec![err(2u8), ok(3u8)].into_iter());
        let retry = FutureRetry::new(attempts, |_| {
            RetryPolicy::WaitRetry::<u8>(Duration::from_millis(10))
        })
        .into_future();
        assert_eq!(retry.await, Ok((3, 2)));
    }

    #[tokio::test]
    async fn more_complicated_repeat() {
        let attempts = FutureIterator(vec![err(2u8), ok(3u8)].into_iter());
        let retry = FutureRetry::new(attempts, |_| RetryPolicy::Repeat::<u8>);
        assert_eq!(retry.await, Ok((3u8, 2)));
    }

    #[tokio::test]
    async fn more_complicated_forward() {
        let attempts = FutureIterator(vec![err(2u8), ok(3u8)].into_iter());
        let retry = FutureRetry::new(attempts, RetryPolicy::ForwardError);
        assert_eq!(retry.await, Err((2u8, 1)));
    }
}