futures_retry/
stream.rs

1use crate::{ErrorHandler, RetryPolicy};
2use futures::{ready, Stream, TryStream};
3use pin_project_lite::pin_project;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8};
9use tokio::time;
10
11pin_project! {
12    /// Provides a way to handle errors during a `Stream` execution, i.e. it gives you an ability to
13    /// poll for future stream's items with a delay.
14    ///
15    /// This type is similar to [`FutureRetry`](struct.FutureRetry.html), but with a different
16    /// semantics. For example, if for [`FutureRetry`](struct.FutureRetry.html) we need a factory that
17    /// creates `Future`s, we don't need one for `Stream`s, since `Stream` itself is a natural producer
18    /// of new items, so we don't have to recreated it if an error is encountered.
19    ///
20    /// A typical usage might be recovering from connection errors while trying to accept a connection
21    /// on a TCP server.
22    ///
23    /// A `tcp-listener` example is available in the `examples` folder.
24    ///
25    /// Also have a look at [`StreamRetryExt`](trait.StreamRetryExt.html) trait for a more convenient
26    /// usage.
27    pub struct StreamRetry<F, S> {
28        error_action: F,
29        #[pin]
30        stream: S,
31        attempt: usize,
32        #[pin]
33        state: RetryState,
34    }
35}
36
37/// An extention trait for `Stream` which allows to use `StreamRetry` in a chain-like manner.
38///
39/// # Example
40///
41/// This magic trait allows you to handle errors on streams in a very neat manner:
42///
43/// ```
44/// // ...
45/// use futures_retry::{RetryPolicy, StreamRetryExt};
46/// # use futures::{TryStreamExt, TryFutureExt, future::{ok, select}, FutureExt, stream};
47/// # use std::io;
48/// # use std::time::Duration;
49/// # use tokio::net::{TcpListener, TcpStream};
50///
51/// fn handle_error(e: io::Error) -> RetryPolicy<io::Error> {
52///   match e.kind() {
53///     io::ErrorKind::Interrupted => RetryPolicy::Repeat,
54///     io::ErrorKind::PermissionDenied => RetryPolicy::ForwardError(e),
55///     _ => RetryPolicy::WaitRetry(Duration::from_millis(5)),
56///   }
57/// }
58///
59/// async fn serve_connection(stream: TcpStream) {
60///   // ...
61/// }
62///
63/// #[tokio::main]
64/// async fn main() {
65///   let mut listener: TcpListener = // ...
66///   # TcpListener::bind("[::]:0").await.unwrap();
67///   let server = stream::try_unfold(listener, |listener| async move {
68///     Ok(Some((listener.accept().await?.0, listener)))
69///   })
70///   .retry(handle_error)
71///   .and_then(|(stream, _attempt)| {
72///     tokio::spawn(serve_connection(stream));
73///     ok(())
74///   })
75///   .try_for_each(|_| ok(()))
76///   .map_err(|(e, _attempt)| eprintln!("Caught an error {}", e));
77///
78///   # // This nasty hack is required to exit immediately when running the doc tests.
79///   # futures::pin_mut!(server);
80///   # let server = select(ok::<_, ()>(()), server).map(|_| ());
81///   server.await
82/// }
83/// ```
84pub trait StreamRetryExt: TryStream {
85    /// Converts the stream into a **retry stream**. See `StreamRetry::new` for details.
86    fn retry<F>(self, error_action: F) -> StreamRetry<F, Self>
87    where
88        Self: Sized,
89    {
90        StreamRetry::new(self, error_action)
91    }
92}
93
94impl<S: ?Sized> StreamRetryExt for S where S: TryStream {}
95
96pin_project! {
97    #[project = RetryStateProj]
98    enum RetryState {
99        WaitingForStream,
100        TimerActive { #[pin] delay: time::Sleep },
101    }
102}
103
104impl<F, S> StreamRetry<F, S> {
105    /// Creates a `StreamRetry` using a provided stream and an object of `ErrorHandler` type that
106    /// decides on a retry-policy depending on an encountered error.
107    ///
108    /// Please refer to the `tcp-listener` example in the `examples` folder to have a look at a
109    /// possible usage or to a very convenient extension trait
110    /// [`StreamRetryExt`](trait.StreamRetryExt.html).
111    ///
112    /// # Arguments
113    ///
114    /// * `stream`: a stream of future items,
115    /// * `error_action`: a type that handles an error and decides which route to take: simply
116    ///                   try again, wait and then try, or give up (on a critical error for
117    ///                   exapmle).
118    pub fn new(stream: S, error_action: F) -> Self
119    where
120        S: TryStream,
121    {
122        Self::with_counter(stream, error_action, 1)
123    }
124
125    /// Like a `new` method, but a custom attempt counter initial value might be provided.
126    pub fn with_counter(stream: S, error_action: F, attempt_counter: usize) -> Self {
127        Self {
128            error_action,
129            stream,
130            attempt: attempt_counter,
131            state: RetryState::WaitingForStream,
132        }
133    }
134}
135
136impl<F, S> Stream for StreamRetry<F, S>
137where
138    S: TryStream,
139    F: ErrorHandler<S::Error>,
140{
141    type Item = Result<(S::Ok, usize), (F::OutError, usize)>;
142
143    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
144        loop {
145            let this = self.as_mut().project();
146            let attempt = *this.attempt;
147            let new_state = match this.state.project() {
148                RetryStateProj::TimerActive { delay } => {
149                    ready!(delay.poll(cx));
150                    RetryState::WaitingForStream
151                }
152                RetryStateProj::WaitingForStream => match ready!(this.stream.try_poll_next(cx)) {
153                    Some(Ok(x)) => {
154                        *this.attempt = 1;
155                        this.error_action.ok(attempt);
156                        return Poll::Ready(Some(Ok((x, attempt))));
157                    }
158                    None => {
159                        return Poll::Ready(None);
160                    }
161                    Some(Err(e)) => {
162                        *this.attempt += 1;
163                        match this.error_action.handle(attempt, e) {
164                            RetryPolicy::ForwardError(e) => {
165                                return Poll::Ready(Some(Err((e, attempt))))
166                            }
167                            RetryPolicy::Repeat => RetryState::WaitingForStream,
168                            RetryPolicy::WaitRetry(duration) => RetryState::TimerActive {
169                                delay: time::sleep(duration),
170                            },
171                        }
172                    }
173                },
174            };
175            self.as_mut().project().state.set(new_state);
176        }
177    }
178}
179
#[cfg(test)]
mod test {
    use super::*;
    use futures::{pin_mut, prelude::*};
    use std::time::Duration;

    #[tokio::test]
    async fn naive() {
        let stream = stream::iter(vec![Ok::<_, u8>(17u8), Ok(19u8)]);
        let retry = StreamRetry::new(stream, |_| RetryPolicy::Repeat::<()>);
        assert_eq!(
            Ok(vec![(17, 1), (19, 1)]),
            retry.try_collect::<Vec<_>>().await,
        );
    }

    #[tokio::test]
    async fn repeat() {
        let stream = stream::iter(vec![Ok(1), Err(17), Ok(19)]);
        let retry = StreamRetry::new(stream, |_| RetryPolicy::Repeat::<()>);
        assert_eq!(
            Ok(vec![(1, 1), (19, 2)]),
            retry.try_collect::<Vec<_>>().await,
        );
    }

    #[tokio::test]
    async fn wait() {
        let stream = stream::iter(vec![Err(17), Ok(19)]);
        let retry = StreamRetry::new(stream, |_| {
            RetryPolicy::WaitRetry::<()>(Duration::from_millis(10))
        });
        assert_eq!(Ok(vec![(19, 2)]), retry.try_collect::<Vec<_>>().await);
    }

    #[tokio::test]
    async fn propagate() {
        let stream = stream::iter(vec![Err(17u8), Ok(19u16)]);
        let retry = StreamRetry::new(stream, RetryPolicy::ForwardError);
        pin_mut!(retry);
        assert_eq!(Some(Err((17u8, 1))), retry.next().await);
    }

    #[tokio::test]
    async fn forwarded_errors_keep_counting_until_success() {
        let stream = stream::iter(vec![Err(1u8), Err(2u8), Ok(3u8)]);
        let retry = StreamRetry::new(stream, RetryPolicy::ForwardError);
        pin_mut!(retry);
        assert_eq!(Some(Err((1u8, 1))), retry.next().await);
        assert_eq!(Some(Err((2u8, 2))), retry.next().await);
        assert_eq!(Some(Ok((3u8, 3))), retry.next().await);
        assert_eq!(None, retry.next().await);
    }

    #[tokio::test]
    async fn custom_counter_resets_to_one_after_success() {
        let stream = stream::iter(vec![Err(1u8), Ok(2u8), Ok(3u8)]);
        let retry = StreamRetry::with_counter(stream, |_| RetryPolicy::Repeat::<()>, 5);
        assert_eq!(
            Ok(vec![(2, 6), (3, 1)]),
            retry.try_collect::<Vec<_>>().await,
        );
    }

    #[tokio::test]
    async fn extension_trait() {
        let retry = stream::iter(vec![Err(17u8), Ok(19u8)]).retry(|_| RetryPolicy::Repeat::<()>);
        assert_eq!(Ok(vec![(19, 2)]), retry.try_collect::<Vec<_>>().await);
    }
}