1use crate::{ErrorHandler, RetryPolicy};
2use futures::{ready, TryFuture};
3use pin_project_lite::pin_project;
4use std::{
5 future::Future,
6 marker::Unpin,
7 pin::Pin,
8 task::{Context, Poll},
9};
10use tokio::time;
11
12pub trait FutureFactory {
20 type FutureItem: TryFuture;
22
23 fn new(&mut self) -> Self::FutureItem;
26}
27
28impl<T, F> FutureFactory for T
29where
30 T: Unpin + FnMut() -> F,
31 F: TryFuture,
32{
33 type FutureItem = F;
34
35 #[allow(clippy::new_ret_no_self)]
36 fn new(&mut self) -> F {
37 (self)()
38 }
39}
40
pin_project! {
    /// A future that retries an operation. Each attempt is a fresh future created
    /// by a [`FutureFactory`], and every failure is passed to an error handler that
    /// decides one of three things: retry immediately, retry after a delay, or give up.
    ///
    /// Resolves to the successful value or the forwarded error, paired with the
    /// number of the attempt that produced it.
    pub struct FutureRetry<F, R>
    where
        F: FutureFactory,
    {
        // Creates the future for each attempt.
        factory: F,
        // Notified of success and consulted on every failure to pick the next step.
        error_action: R,
        // 1-based number of the attempt currently running (or about to run).
        attempt: usize,
        // Current position in the retry state machine. Pinned because it holds the
        // in-flight future or the backoff timer.
        #[pin]
        state: RetryState<F::FutureItem>,
    }
}
60
pin_project! {
    // States of the retry loop driven by `FutureRetry::poll`.
    #[project = RetryStateProj]
    enum RetryState<F> {
        // No attempt has been created yet; the first poll creates one.
        NotStarted,
        // An attempt's future is in flight.
        WaitingForFuture { #[pin] future: F },
        // Waiting out a `RetryPolicy::WaitRetry` delay before the next attempt.
        TimerActive { #[pin] delay: time::Sleep },
    }
}
69
70impl<F: FutureFactory, R> FutureRetry<F, R> {
71 pub fn new(factory: F, error_action: R) -> Self {
84 Self {
85 factory,
86 error_action,
87 state: RetryState::NotStarted,
88 attempt: 1,
89 }
90 }
91}
92
impl<F: FutureFactory, R> Future for FutureRetry<F, R>
where
    R: ErrorHandler<<F::FutureItem as TryFuture>::Error>,
{
    // On success: the inner future's `Ok` value. On failure: the handler's forwarded
    // error. Either way the value is paired with the number of the attempt that
    // produced it.
    type Output =
        Result<(<<F as FutureFactory>::FutureItem as TryFuture>::Ok, usize), (R::OutError, usize)>;

    /// Drives the retry state machine.
    ///
    /// Returns `Poll::Pending` (via `ready!`) whenever the current attempt or the
    /// backoff timer is not ready yet.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        // Each iteration handles the current state and computes the next one. The
        // loop only exits by returning a final result or `Poll::Pending`.
        loop {
            let this = self.as_mut().project();
            // Number of the attempt whose future is (or is about to be) in flight.
            let attempt = *this.attempt;
            let new_state = match this.state.project() {
                // First poll: create the first attempt's future.
                RetryStateProj::NotStarted => RetryState::WaitingForFuture {
                    future: this.factory.new(),
                },
                // Backoff delay: once it elapses, create the next attempt's future.
                RetryStateProj::TimerActive { delay } => {
                    ready!(delay.poll(cx));
                    RetryState::WaitingForFuture {
                        future: this.factory.new(),
                    }
                }
                RetryStateProj::WaitingForFuture { future } => match ready!(future.try_poll(cx)) {
                    Ok(x) => {
                        // Notify the handler of the success, then reset the counter.
                        this.error_action.ok(attempt);
                        *this.attempt = 1;
                        return Poll::Ready(Ok((x, attempt)));
                    }
                    Err(e) => {
                        // Advance the counter for the next attempt. The handler, and any
                        // forwarded error, still see the number of the attempt that failed.
                        *this.attempt += 1;
                        match this.error_action.handle(attempt, e) {
                            RetryPolicy::ForwardError(e) => return Poll::Ready(Err((e, attempt))),
                            // Retry at once. The new future is polled on the next loop
                            // iteration, within this same `poll` call.
                            RetryPolicy::Repeat => RetryState::WaitingForFuture {
                                future: this.factory.new(),
                            },
                            // Retry only after `duration` has elapsed.
                            RetryPolicy::WaitRetry(duration) => RetryState::TimerActive {
                                delay: time::sleep(duration),
                            },
                        }
                    }
                },
            };

            // `this` is no longer used, so the state can be re-projected and replaced.
            // This drops the previous future or timer.
            self.as_mut().project().state.set(new_state);
        }
    }
}
139
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        future::{err, ok},
        TryFutureExt,
    };
    use std::time::Duration;

    /// Test factory that hands out pre-built futures from an iterator, one per attempt.
    struct FutureIterator<I>(I);

    impl<I, F> FutureFactory for FutureIterator<I>
    where
        I: Unpin + Iterator<Item = F>,
        F: TryFuture,
    {
        type FutureItem = F;

        fn new(&mut self) -> Self::FutureItem {
            // Running dry means the retry loop made more attempts than the test
            // scripted, which is itself a test failure.
            self.0.next().expect("No more futures!")
        }
    }

    #[tokio::test]
    async fn naive() {
        let f = FutureRetry::new(|| ok::<_, u8>(1u8), |_| RetryPolicy::Repeat::<u8>);
        assert_eq!(Ok((1u8, 1)), f.await);
    }

    #[tokio::test]
    async fn naive_error_forward() {
        let f = FutureRetry::new(|| err::<u8, _>(1u8), RetryPolicy::ForwardError);
        assert_eq!(Err((1u8, 1)), f.await);
    }

    #[tokio::test]
    async fn more_complicated_wait() {
        let f = FutureRetry::new(FutureIterator(vec![err(2u8), ok(3u8)].into_iter()), |_| {
            RetryPolicy::WaitRetry::<u8>(Duration::from_millis(10))
        })
        .into_future();
        assert_eq!(Ok((3, 2)), f.await);
    }

    #[tokio::test]
    async fn more_complicated_repeat() {
        let f = FutureRetry::new(FutureIterator(vec![err(2u8), ok(3u8)].into_iter()), |_| {
            RetryPolicy::Repeat::<u8>
        });
        assert_eq!(Ok((3u8, 2)), f.await);
    }

    #[tokio::test]
    async fn more_complicated_forward() {
        let f = FutureRetry::new(
            FutureIterator(vec![err(2u8), ok(3u8)].into_iter()),
            RetryPolicy::ForwardError,
        );
        assert_eq!(Err((2u8, 1)), f.await);
    }

    /// Several immediate retries: the reported attempt counts every failure.
    #[tokio::test]
    async fn repeat_until_success() {
        let f = FutureRetry::new(
            FutureIterator(vec![err(1u8), err(2u8), ok(3u8)].into_iter()),
            |_| RetryPolicy::Repeat::<u8>,
        );
        assert_eq!(Ok((3u8, 3)), f.await);
    }

    /// An error forwarded after a retry reports the number of the attempt that failed.
    #[tokio::test]
    async fn forward_after_repeat() {
        let f = FutureRetry::new(
            FutureIterator(vec![err::<u8, u8>(2), err(5)].into_iter()),
            |e: u8| {
                if e == 2 {
                    RetryPolicy::Repeat
                } else {
                    RetryPolicy::ForwardError(e)
                }
            },
        );
        assert_eq!(Err((5u8, 2)), f.await);
    }
}