1use std::future::Future;
2use std::hash::Hash;
3use std::pin::Pin;
4use std::task::{Context, Poll, Waker};
5use std::time::Duration;
6use std::{future, mem};
7
8use futures_timer::Delay;
9use futures_util::future::BoxFuture;
10use futures_util::stream::FuturesUnordered;
11use futures_util::{FutureExt, StreamExt};
12
13use crate::{PushError, Timeout};
14
/// A capacity-bounded collection of futures, each registered under an `ID` and raced against a
/// timeout.
///
/// Every pushed future gets its own `Delay` of `timeout`; the collection refuses new futures once
/// it holds `capacity` entries.
pub struct FuturesMap<ID, O> {
    /// Duration given to each pushed future before it is reported as timed out.
    timeout: Duration,
    /// Maximum number of entries `inner` may hold; `try_push` rejects futures beyond it.
    capacity: usize,
    /// In-flight futures, each tagged with its ID and wrapped in a `TimeoutFuture`.
    inner: FuturesUnordered<TaggedFuture<ID, TimeoutFuture<BoxFuture<'static, O>>>>,
    /// Waker of the task that polled the map while it was empty; `try_push` wakes it.
    empty_waker: Option<Waker>,
    /// Waker registered by `poll_ready_unpin` when the map was at capacity.
    full_waker: Option<Waker>,
}
25
26impl<ID, O> FuturesMap<ID, O> {
27 pub fn new(timeout: Duration, capacity: usize) -> Self {
28 Self {
29 timeout,
30 capacity,
31 inner: Default::default(),
32 empty_waker: None,
33 full_waker: None,
34 }
35 }
36}
37
38impl<ID, O> FuturesMap<ID, O>
39where
40 ID: Clone + Hash + Eq + Send + Unpin + 'static,
41 O: 'static,
42{
43 pub fn try_push<F>(&mut self, future_id: ID, future: F) -> Result<(), PushError<BoxFuture<O>>>
51 where
52 F: Future<Output = O> + Send + 'static,
53 {
54 if self.inner.len() >= self.capacity {
55 return Err(PushError::BeyondCapacity(future.boxed()));
56 }
57
58 if let Some(waker) = self.empty_waker.take() {
59 waker.wake();
60 }
61
62 let old = self.remove(future_id.clone());
63 self.inner.push(TaggedFuture {
64 tag: future_id,
65 inner: TimeoutFuture {
66 inner: future.boxed(),
67 timeout: Delay::new(self.timeout),
68 cancelled: false,
69 },
70 });
71 match old {
72 None => Ok(()),
73 Some(old) => Err(PushError::Replaced(old)),
74 }
75 }
76
77 pub fn remove(&mut self, id: ID) -> Option<BoxFuture<'static, O>> {
78 let tagged = self.inner.iter_mut().find(|s| s.tag == id)?;
79
80 let inner = mem::replace(&mut tagged.inner.inner, future::pending().boxed());
81 tagged.inner.cancelled = true;
82
83 Some(inner)
84 }
85
86 pub fn contains(&self, id: ID) -> bool {
87 self.inner.iter().any(|f| f.tag == id && !f.inner.cancelled)
88 }
89
90 pub fn len(&self) -> usize {
91 self.inner.len()
92 }
93
94 pub fn is_empty(&self) -> bool {
95 self.inner.is_empty()
96 }
97
98 #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
100 if self.inner.len() < self.capacity {
101 return Poll::Ready(());
102 }
103
104 self.full_waker = Some(cx.waker().clone());
105
106 Poll::Pending
107 }
108
109 pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result<O, Timeout>)> {
110 loop {
111 let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
112
113 match maybe_result {
114 None => {
115 self.empty_waker = Some(cx.waker().clone());
116 return Poll::Pending;
117 }
118 Some((id, Ok(output))) => return Poll::Ready((id, Ok(output))),
119 Some((id, Err(TimeoutError::Timeout))) => {
120 return Poll::Ready((id, Err(Timeout::new(self.timeout))))
121 }
122 Some((_, Err(TimeoutError::Cancelled))) => continue,
123 }
124 }
125 }
126}
127
/// Wraps a future `F`, racing it against a `Delay` and allowing it to be cancelled.
struct TimeoutFuture<F> {
    /// The wrapped future whose output is forwarded as `Ok`.
    inner: F,
    /// Timer checked on every poll; once it is ready the future resolves to `TimeoutError::Timeout`.
    timeout: Delay,

    /// Set by `FuturesMap::remove`; makes the next poll resolve to `TimeoutError::Cancelled`.
    cancelled: bool,
}
134
135impl<F> Future for TimeoutFuture<F>
136where
137 F: Future + Unpin,
138{
139 type Output = Result<F::Output, TimeoutError>;
140
141 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142 if self.cancelled {
143 return Poll::Ready(Err(TimeoutError::Cancelled));
144 }
145
146 if self.timeout.poll_unpin(cx).is_ready() {
147 return Poll::Ready(Err(TimeoutError::Timeout));
148 }
149
150 self.inner.poll_unpin(cx).map(Ok)
151 }
152}
153
/// Reason a `TimeoutFuture` resolved without its inner future's output.
enum TimeoutError {
    /// The timer fired before the inner future completed.
    Timeout,
    /// The future was cancelled via `FuturesMap::remove`.
    Cancelled,
}
158
/// Pairs a future with a tag so its output can be attributed once it resolves.
struct TaggedFuture<T, F> {
    /// Identifier returned (cloned) alongside the inner future's output.
    tag: T,
    /// The wrapped future.
    inner: F,
}
163
164impl<T, F> Future for TaggedFuture<T, F>
165where
166 T: Clone + Unpin,
167 F: Future + Unpin,
168{
169 type Output = (T, F::Output);
170
171 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172 let output = futures_util::ready!(self.inner.poll_unpin(cx));
173
174 Poll::Ready((self.tag.clone(), output))
175 }
176}
177
#[cfg(test)]
mod tests {
    use futures::channel::oneshot;
    use futures_util::task::noop_waker_ref;
    use std::future::{pending, poll_fn, ready};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 1);

        assert!(futures.try_push("ID_1", ready(())).is_ok());
        // `matches!` only yields a `bool`; without `assert!` the check is silently discarded.
        assert!(matches!(
            futures.try_push("ID_2", ready(())),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 5);

        assert!(futures.try_push("ID", ready(())).is_ok());
        assert!(matches!(
            futures.try_push("ID", ready(())),
            Err(PushError::Replaced(_))
        ));
    }

    #[tokio::test]
    async fn futures_timeout() {
        let mut futures = FuturesMap::new(Duration::from_millis(100), 1);

        let _ = futures.try_push("ID", pending::<()>());
        Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| futures.poll_unpin(cx)).await;

        assert!(result.is_err())
    }

    #[test]
    fn resources_of_removed_future_are_cleaned_up() {
        let mut futures = FuturesMap::new(Duration::from_millis(100), 1);

        let _ = futures.try_push("ID", pending::<()>());
        futures.remove("ID");

        let poll = futures.poll_unpin(&mut Context::from_waker(noop_waker_ref()));
        assert!(poll.is_pending());

        assert_eq!(futures.len(), 0);
    }

    #[tokio::test]
    async fn replaced_pending_future_is_polled() {
        let mut streams = FuturesMap::new(Duration::from_millis(100), 3);

        let (_tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        let _ = streams.try_push("ID1", rx1);
        let _ = streams.try_push("ID2", rx2);

        let _ = tx2.send(2);
        let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;
        assert_eq!(id, "ID2");
        assert_eq!(res.unwrap().unwrap(), 2);

        let (new_tx1, new_rx1) = oneshot::channel();
        let replaced = streams.try_push("ID1", new_rx1);
        assert!(matches!(replaced.unwrap_err(), PushError::Replaced(_)));

        let _ = new_tx1.send(4);
        let (id, res) = poll_fn(|cx| streams.poll_unpin(cx)).await;

        assert_eq!(id, "ID1");
        assert_eq!(res.unwrap().unwrap(), 4);
    }

    // Each future waits for `DELAY` and `Task` drives a map with a capacity of 1, so the futures
    // must run one after another: completing `NUM_FUTURES` of them takes at least
    // `DELAY * NUM_FUTURES`.
    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_FUTURES: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_FUTURES, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_FUTURES);
    }

    #[test]
    fn contains() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 1);
        _ = futures.try_push("ID", pending::<()>());
        assert!(futures.contains("ID"));
        _ = futures.remove("ID");
        assert!(!futures.contains("ID"));
    }

    /// Test driver that keeps a bounded `FuturesMap` saturated with `Delay` futures, pushing a
    /// new one whenever `poll_ready_unpin` reports room, until `num_futures` have completed.
    struct Task {
        /// How long each pushed future waits before completing.
        future: Duration,
        /// Number of completed futures after which the task resolves.
        num_futures: usize,
        /// Number of futures completed so far.
        num_processed: usize,
        inner: FuturesMap<u8, ()>,
    }

    impl Task {
        fn new(future: Duration, num_futures: u32, capacity: usize) -> Self {
            Self {
                future,
                num_futures: num_futures as usize,
                num_processed: 0,
                // Far longer than any per-future delay used by the tests, so nothing times out.
                inner: FuturesMap::new(Duration::from_secs(60), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_futures {
                if let Poll::Ready((_, result)) = this.inner.poll_unpin(cx) {
                    assert!(
                        result.is_ok(),
                        "future timed out; the map's timeout must exceed the per-future delay"
                    );

                    this.num_processed += 1;
                    continue;
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    // All futures share ID `1`; with capacity 1 at most one is in flight anyway.
                    let maybe_future = this.inner.try_push(1u8, Delay::new(this.future));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}