futures_retry/stream.rs
1use crate::{ErrorHandler, RetryPolicy};
2use futures::{ready, Stream, TryStream};
3use pin_project_lite::pin_project;
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{Context, Poll},
8};
9use tokio::time;
10
11pin_project! {
12 /// Provides a way to handle errors during a `Stream` execution, i.e. it gives you an ability to
13 /// poll for future stream's items with a delay.
14 ///
15 /// This type is similar to [`FutureRetry`](struct.FutureRetry.html), but with a different
16 /// semantics. For example, if for [`FutureRetry`](struct.FutureRetry.html) we need a factory that
17 /// creates `Future`s, we don't need one for `Stream`s, since `Stream` itself is a natural producer
18 /// of new items, so we don't have to recreated it if an error is encountered.
19 ///
20 /// A typical usage might be recovering from connection errors while trying to accept a connection
21 /// on a TCP server.
22 ///
23 /// A `tcp-listener` example is available in the `examples` folder.
24 ///
25 /// Also have a look at [`StreamRetryExt`](trait.StreamRetryExt.html) trait for a more convenient
26 /// usage.
27 pub struct StreamRetry<F, S> {
28 error_action: F,
29 #[pin]
30 stream: S,
31 attempt: usize,
32 #[pin]
33 state: RetryState,
34 }
35}
36
37/// An extention trait for `Stream` which allows to use `StreamRetry` in a chain-like manner.
38///
39/// # Example
40///
41/// This magic trait allows you to handle errors on streams in a very neat manner:
42///
43/// ```
44/// // ...
45/// use futures_retry::{RetryPolicy, StreamRetryExt};
46/// # use futures::{TryStreamExt, TryFutureExt, future::{ok, select}, FutureExt, stream};
47/// # use std::io;
48/// # use std::time::Duration;
49/// # use tokio::net::{TcpListener, TcpStream};
50///
51/// fn handle_error(e: io::Error) -> RetryPolicy<io::Error> {
52/// match e.kind() {
53/// io::ErrorKind::Interrupted => RetryPolicy::Repeat,
54/// io::ErrorKind::PermissionDenied => RetryPolicy::ForwardError(e),
55/// _ => RetryPolicy::WaitRetry(Duration::from_millis(5)),
56/// }
57/// }
58///
59/// async fn serve_connection(stream: TcpStream) {
60/// // ...
61/// }
62///
63/// #[tokio::main]
64/// async fn main() {
65/// let mut listener: TcpListener = // ...
66/// # TcpListener::bind("[::]:0").await.unwrap();
67/// let server = stream::try_unfold(listener, |listener| async move {
68/// Ok(Some((listener.accept().await?.0, listener)))
69/// })
70/// .retry(handle_error)
71/// .and_then(|(stream, _attempt)| {
72/// tokio::spawn(serve_connection(stream));
73/// ok(())
74/// })
75/// .try_for_each(|_| ok(()))
76/// .map_err(|(e, _attempt)| eprintln!("Caught an error {}", e));
77///
78/// # // This nasty hack is required to exit immediately when running the doc tests.
79/// # futures::pin_mut!(server);
80/// # let server = select(ok::<_, ()>(()), server).map(|_| ());
81/// server.await
82/// }
83/// ```
84pub trait StreamRetryExt: TryStream {
85 /// Converts the stream into a **retry stream**. See `StreamRetry::new` for details.
86 fn retry<F>(self, error_action: F) -> StreamRetry<F, Self>
87 where
88 Self: Sized,
89 {
90 StreamRetry::new(self, error_action)
91 }
92}
93
94impl<S: ?Sized> StreamRetryExt for S where S: TryStream {}
95
96pin_project! {
97 #[project = RetryStateProj]
98 enum RetryState {
99 WaitingForStream,
100 TimerActive { #[pin] delay: time::Sleep },
101 }
102}
103
104impl<F, S> StreamRetry<F, S> {
105 /// Creates a `StreamRetry` using a provided stream and an object of `ErrorHandler` type that
106 /// decides on a retry-policy depending on an encountered error.
107 ///
108 /// Please refer to the `tcp-listener` example in the `examples` folder to have a look at a
109 /// possible usage or to a very convenient extension trait
110 /// [`StreamRetryExt`](trait.StreamRetryExt.html).
111 ///
112 /// # Arguments
113 ///
114 /// * `stream`: a stream of future items,
115 /// * `error_action`: a type that handles an error and decides which route to take: simply
116 /// try again, wait and then try, or give up (on a critical error for
117 /// exapmle).
118 pub fn new(stream: S, error_action: F) -> Self
119 where
120 S: TryStream,
121 {
122 Self::with_counter(stream, error_action, 1)
123 }
124
125 /// Like a `new` method, but a custom attempt counter initial value might be provided.
126 pub fn with_counter(stream: S, error_action: F, attempt_counter: usize) -> Self {
127 Self {
128 error_action,
129 stream,
130 attempt: attempt_counter,
131 state: RetryState::WaitingForStream,
132 }
133 }
134}
135
136impl<F, S> Stream for StreamRetry<F, S>
137where
138 S: TryStream,
139 F: ErrorHandler<S::Error>,
140{
141 type Item = Result<(S::Ok, usize), (F::OutError, usize)>;
142
143 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
144 loop {
145 let this = self.as_mut().project();
146 let attempt = *this.attempt;
147 let new_state = match this.state.project() {
148 RetryStateProj::TimerActive { delay } => {
149 ready!(delay.poll(cx));
150 RetryState::WaitingForStream
151 }
152 RetryStateProj::WaitingForStream => match ready!(this.stream.try_poll_next(cx)) {
153 Some(Ok(x)) => {
154 *this.attempt = 1;
155 this.error_action.ok(attempt);
156 return Poll::Ready(Some(Ok((x, attempt))));
157 }
158 None => {
159 return Poll::Ready(None);
160 }
161 Some(Err(e)) => {
162 *this.attempt += 1;
163 match this.error_action.handle(attempt, e) {
164 RetryPolicy::ForwardError(e) => {
165 return Poll::Ready(Some(Err((e, attempt))))
166 }
167 RetryPolicy::Repeat => RetryState::WaitingForStream,
168 RetryPolicy::WaitRetry(duration) => RetryState::TimerActive {
169 delay: time::sleep(duration),
170 },
171 }
172 }
173 },
174 };
175 self.as_mut().project().state.set(new_state);
176 }
177 }
178}
179
#[cfg(test)]
mod test {
    use super::*;
    use futures::{pin_mut, prelude::*};
    use std::time::Duration;

    #[tokio::test]
    async fn naive() {
        let stream = stream::iter(vec![Ok::<_, u8>(17u8), Ok(19u8)]);
        let retry = StreamRetry::new(stream, |_| RetryPolicy::Repeat::<()>);
        assert_eq!(
            Ok(vec![(17, 1), (19, 1)]),
            retry.try_collect::<Vec<_>>().await,
        );
    }

    #[tokio::test]
    async fn repeat() {
        let stream = stream::iter(vec![Ok(1), Err(17), Ok(19)]);
        let retry = StreamRetry::new(stream, |_| RetryPolicy::Repeat::<()>);
        assert_eq!(
            Ok(vec![(1, 1), (19, 2)]),
            retry.try_collect::<Vec<_>>().await,
        );
    }

    #[tokio::test]
    async fn repeat_several_errors() {
        // The counter grows with each consecutive error and resets after every success.
        let stream = stream::iter(vec![Err(1), Err(2), Ok(3), Err(4), Ok(5)]);
        let retry = StreamRetry::new(stream, |_| RetryPolicy::Repeat::<()>);
        assert_eq!(
            Ok(vec![(3, 3), (5, 2)]),
            retry.try_collect::<Vec<_>>().await,
        );
    }

    #[tokio::test]
    async fn custom_counter() {
        // The custom start value only applies until the first success.
        let stream = stream::iter(vec![Err(17), Ok(19), Ok(23)]);
        let retry = StreamRetry::with_counter(stream, |_| RetryPolicy::Repeat::<()>, 5);
        assert_eq!(
            Ok(vec![(19, 6), (23, 1)]),
            retry.try_collect::<Vec<_>>().await,
        );
    }

    #[tokio::test]
    async fn wait() {
        let stream = stream::iter(vec![Err(17), Ok(19)]);
        let retry = StreamRetry::new(stream, |_| {
            RetryPolicy::WaitRetry::<()>(Duration::from_millis(10))
        });
        assert_eq!(Ok(vec![(19, 2)]), retry.try_collect::<Vec<_>>().await);
    }

    #[tokio::test]
    async fn propagate() {
        let stream = stream::iter(vec![Err(17u8), Ok(19u16)]);
        let retry = StreamRetry::new(stream, RetryPolicy::ForwardError);
        pin_mut!(retry);
        assert_eq!(Some(Err((17u8, 1))), retry.next().await);
        // A forwarded error does not terminate the stream: later items are still delivered.
        assert!(matches!(retry.next().await, Some(Ok((19u16, _)))));
        assert_eq!(None, retry.next().await);
    }
}
224}