1use futures::{stream::Fuse, Stream, StreamExt};
4use hashbrown::{hash_map::RawEntryMut, HashMap};
5use pin_project::pin_project;
6use std::{
7 collections::HashSet,
8 hash::Hash,
9 pin::Pin,
10 task::{Context, Poll},
11 time::Duration,
12};
13use tokio::time::Instant;
14use tokio_util::time::delay_queue::{self, DelayQueue};
15
/// A request to emit `message` from a [`Scheduler`] at the instant `run_at`.
#[derive(Debug)]
pub struct ScheduleRequest<T> {
    /// The message to emit once the request is due.
    pub message: T,
    /// The requested emission time; the scheduler adds its debounce duration on top of this.
    pub run_at: Instant,
}
22
/// Bookkeeping stored in the scheduler's `scheduled` map for a message that sits in the delay queue.
struct ScheduledEntry {
    /// Deadline currently set for the message in the delay queue (requested time plus debounce).
    run_at: Instant,
    /// Key of the message's slot in the delay queue, needed to reset its deadline later.
    queue_key: delay_queue::Key,
}
28
/// Stream adaptor that consumes [`ScheduleRequest`]s and emits each message once it is due,
/// keeping at most one scheduled entry per (equal) message.
#[pin_project(project = SchedulerProj)]
pub struct Scheduler<T, R> {
    /// Delay queue holding a copy of every scheduled message until its deadline passes.
    queue: DelayQueue<T>,
    /// Metadata for every message currently in `queue`, keyed by the message itself.
    ///
    /// When an equal message is rescheduled, the key stored here is replaced by the newer
    /// message, so this map (not `queue`) holds the most recent version of each message.
    scheduled: HashMap<T, ScheduledEntry>,
    /// Messages that have expired from `queue` but were held back instead of being emitted.
    pending: HashSet<T>,
    /// Incoming stream of scheduling requests, fused so it can be polled after it has ended.
    #[pin]
    requests: Fuse<R>,
    /// Extra delay added to each request's `run_at` when it is scheduled.
    debounce: Duration,
}
55
56impl<T, R: Stream> Scheduler<T, R> {
57 fn new(requests: R, debounce: Duration) -> Self {
58 Self {
59 queue: DelayQueue::new(),
60 scheduled: HashMap::new(),
61 pending: HashSet::new(),
62 requests: requests.fuse(),
63 debounce,
64 }
65 }
66}
67
68impl<T: Hash + Eq + Clone, R> SchedulerProj<'_, T, R> {
69 fn schedule_message(&mut self, request: ScheduleRequest<T>) {
73 if self.pending.contains(&request.message) {
74 return;
76 }
77 let next_time = request
78 .run_at
79 .checked_add(*self.debounce)
80 .unwrap_or_else(far_future);
81 match self.scheduled.raw_entry_mut().from_key(&request.message) {
82 RawEntryMut::Occupied(mut old_entry) if old_entry.get().run_at >= request.run_at => {
86 let entry = old_entry.get_mut();
88 self.queue.reset_at(&entry.queue_key, next_time);
89 entry.run_at = next_time;
90 old_entry.insert_key(request.message);
91 }
92 RawEntryMut::Occupied(_old_entry) => {
93 }
95 RawEntryMut::Vacant(entry) => {
96 let message = request.message.clone();
98 entry.insert(request.message, ScheduledEntry {
99 run_at: next_time,
100 queue_key: self.queue.insert_at(message, next_time),
101 });
102 }
103 }
104 }
105
106 fn poll_pop_queue_message(
108 &mut self,
109 cx: &mut Context<'_>,
110 can_take_message: impl Fn(&T) -> bool,
111 ) -> Poll<T> {
112 if let Some(msg) = self.pending.iter().find(|msg| can_take_message(*msg)).cloned() {
113 return Poll::Ready(self.pending.take(&msg).unwrap());
114 }
115
116 loop {
117 match self.queue.poll_expired(cx) {
118 Poll::Ready(Some(msg)) => {
119 let msg = msg.into_inner();
120 let (msg, _) = self.scheduled.remove_entry(&msg).expect(
121 "Expired message was popped from the Scheduler queue, but was not in the metadata map",
122 );
123 if can_take_message(&msg) {
124 break Poll::Ready(msg);
125 }
126 self.pending.insert(msg);
127 }
128 Poll::Ready(None) | Poll::Pending => break Poll::Pending,
129 }
130 }
131 }
132
133 pub fn pop_queue_message_into_pending(&mut self, cx: &mut Context<'_>) {
135 while let Poll::Ready(Some(msg)) = self.queue.poll_expired(cx) {
136 let msg = msg.into_inner();
137 self.scheduled.remove_entry(&msg).expect(
138 "Expired message was popped from the Scheduler queue, but was not in the metadata map",
139 );
140 self.pending.insert(msg);
141 }
142 }
143}
144
/// Borrowed view of a [`Scheduler`] that never emits messages: polling it keeps accepting
/// requests and moves expired messages into the scheduler's pending set.
pub struct Hold<'a, T, R> {
    /// The scheduler being driven through this view.
    scheduler: Pin<&'a mut Scheduler<T, R>>,
}
149
150impl<T, R> Stream for Hold<'_, T, R>
151where
152 T: Eq + Hash + Clone,
153 R: Stream<Item = ScheduleRequest<T>>,
154{
155 type Item = T;
156
157 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
158 let this = self.get_mut();
159 let mut scheduler = this.scheduler.as_mut().project();
160
161 loop {
162 match scheduler.requests.as_mut().poll_next(cx) {
163 Poll::Ready(Some(request)) => scheduler.schedule_message(request),
164 Poll::Ready(None) => return Poll::Ready(None),
165 Poll::Pending => break,
166 }
167 }
168
169 scheduler.pop_queue_message_into_pending(cx);
170 Poll::Pending
171 }
172}
173
/// Borrowed view of a [`Scheduler`] that only emits messages accepted by `can_take_message`;
/// expired messages that are rejected stay in the scheduler's pending set.
pub struct HoldUnless<'a, T, R, C> {
    /// The scheduler being driven through this view.
    scheduler: Pin<&'a mut Scheduler<T, R>>,
    /// Predicate deciding whether a due message may be emitted now.
    can_take_message: C,
}
179
180impl<T, R, C> Stream for HoldUnless<'_, T, R, C>
181where
182 T: Eq + Hash + Clone,
183 R: Stream<Item = ScheduleRequest<T>>,
184 C: Fn(&T) -> bool + Unpin,
185{
186 type Item = T;
187
188 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
189 let this = self.get_mut();
190 let can_take_message = &this.can_take_message;
191 let mut scheduler = this.scheduler.as_mut().project();
192
193 loop {
194 match scheduler.requests.as_mut().poll_next(cx) {
195 Poll::Ready(Some(request)) => scheduler.schedule_message(request),
196 Poll::Ready(None) => return Poll::Ready(None),
197 Poll::Pending => break,
198 }
199 }
200
201 match scheduler.poll_pop_queue_message(cx, can_take_message) {
202 Poll::Ready(expired) => Poll::Ready(Some(expired)),
203 Poll::Pending => Poll::Pending,
204 }
205 }
206}
207
208impl<T, R> Scheduler<T, R>
209where
210 T: Eq + Hash + Clone,
211 R: Stream<Item = ScheduleRequest<T>>,
212{
213 pub fn hold_unless<C: Fn(&T) -> bool>(self: Pin<&mut Self>, can_take_message: C) -> HoldUnless<T, R, C> {
224 HoldUnless {
225 scheduler: self,
226 can_take_message,
227 }
228 }
229
230 #[must_use]
234 pub fn hold(self: Pin<&mut Self>) -> Hold<T, R> {
235 Hold { scheduler: self }
236 }
237
238 #[cfg(test)]
240 pub fn contains_pending(&self, msg: &T) -> bool {
241 self.pending.contains(msg)
242 }
243}
244
245impl<T, R> Stream for Scheduler<T, R>
246where
247 T: Eq + Hash + Clone,
248 R: Stream<Item = ScheduleRequest<T>>,
249{
250 type Item = T;
251
252 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
253 Pin::new(&mut self.hold_unless(|_| true)).poll_next(cx)
254 }
255}
256
257pub fn scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(requests: S) -> Scheduler<T, S> {
268 Scheduler::new(requests, Duration::ZERO)
269}
270
271#[allow(clippy::module_name_repetitions)]
279pub fn debounced_scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(
280 requests: S,
281 debounce: Duration,
282) -> Scheduler<T, S> {
283 Scheduler::new(requests, debounce)
284}
285
286pub(crate) fn far_future() -> Instant {
288 Instant::now() + Duration::from_secs(86400 * 365 * 30)
291}
292
// Tests for the scheduler. Most of them pause tokio's clock and move it manually with
// `advance`, so the order of the statements in each test matters.
#[cfg(test)]
mod tests {
    use crate::utils::KubeRuntimeStreamExt;

    use super::{debounced_scheduler, scheduler, ScheduleRequest};
    use educe::Educe;
    use futures::{channel::mpsc, future, poll, stream, FutureExt, SinkExt, StreamExt};
    use std::{pin::pin, task::Poll};
    use tokio::time::{advance, pause, sleep, Duration, Instant};

    /// Extracts the value of a ready poll, panicking if the poll is still pending.
    fn unwrap_poll<T>(poll: Poll<T>) -> T {
        if let Poll::Ready(x) = poll {
            x
        } else {
            panic!("Tried to unwrap a pending poll!")
        }
    }

    /// A message whose payload is ignored by `PartialEq` and `Hash`, so all instances compare
    /// equal. The payload tells the tests which version of the message was emitted.
    #[derive(Educe, Eq, Clone, Debug)]
    #[educe(PartialEq, Hash)]
    struct SingletonMessage(#[educe(PartialEq(ignore), Hash(ignore))] u8);

    #[tokio::test]
    async fn scheduler_should_hold_and_release_items() {
        pause();
        let mut scheduler = Box::pin(scheduler(
            stream::iter(vec![ScheduleRequest {
                message: 1_u8,
                run_at: Instant::now(),
            }])
            .on_complete(sleep(Duration::from_secs(4))),
        ));
        assert!(!scheduler.contains_pending(&1));
        // The predicate rejects the message, so it is parked as pending instead of emitted.
        assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
        assert!(scheduler.contains_pending(&1));
        // An accepting predicate releases the pending message.
        assert_eq!(
            unwrap_poll(poll!(scheduler.as_mut().hold_unless(|_| true).next())).unwrap(),
            1_u8
        );
        assert!(!scheduler.contains_pending(&1));
        assert!(scheduler.as_mut().hold_unless(|_| true).next().await.is_none());
    }

    #[tokio::test]
    async fn scheduler_should_not_reschedule_pending_items() {
        pause();
        let (mut tx, rx) = mpsc::unbounded::<ScheduleRequest<u8>>();
        let mut scheduler = Box::pin(scheduler(rx));
        tx.send(ScheduleRequest {
            message: 1,
            run_at: Instant::now(),
        })
        .await
        .unwrap();
        // Park the message as pending.
        assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
        // A second request for the same message arrives while the first one is pending.
        tx.send(ScheduleRequest {
            message: 1,
            run_at: Instant::now(),
        })
        .await
        .unwrap();
        future::join(
            async {
                sleep(Duration::from_secs(2)).await;
                drop(tx);
            },
            async {
                // The message is emitted exactly once before the stream ends.
                assert_eq!(scheduler.next().await.unwrap(), 1);
                assert!(scheduler.next().await.is_none())
            },
        )
        .await;
    }

    #[tokio::test]
    async fn scheduler_pending_message_should_not_block_head_of_line() {
        let mut scheduler = Box::pin(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: 1,
                    run_at: Instant::now(),
                },
                ScheduleRequest {
                    message: 2,
                    run_at: Instant::now(),
                },
            ])
            .on_complete(sleep(Duration::from_secs(2))),
        ));
        // Message 1 is rejected by the predicate; message 2 must still get through.
        assert_eq!(
            scheduler.as_mut().hold_unless(|x| *x != 1).next().await.unwrap(),
            2
        );
    }

    #[tokio::test]
    async fn scheduler_should_emit_items_as_requested() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: 1_u8,
                    run_at: Instant::now() + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: 2,
                    run_at: Instant::now() + Duration::from_secs(3),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        // t = 2s: only the first message is due.
        advance(Duration::from_secs(2)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap(), 1);
        assert!(poll!(scheduler.next()).is_pending());
        // t = 4s: the second message is due.
        advance(Duration::from_secs(2)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap(), 2);
        // The request stream completes, which ends the scheduler stream.
        assert!(scheduler.next().await.is_none());
    }

    #[tokio::test]
    async fn scheduler_dedupe_should_keep_earlier_item() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(3),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        // t = 2s: the earlier request (1s) has fired; the later duplicate was dropped.
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        // No second emission before the stream ends.
        assert!(scheduler.next().await.is_none());
    }

    #[tokio::test]
    async fn scheduler_dedupe_should_replace_later_item() {
        pause();
        let mut scheduler = pin!(scheduler(
            stream::iter(vec![
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(3),
                },
                ScheduleRequest {
                    message: (),
                    run_at: Instant::now() + Duration::from_secs(1),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        ));
        assert!(poll!(scheduler.next()).is_pending());
        // t = 2s: the entry was moved up to the earlier request (1s), so it has fired.
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        // No second emission before the stream ends.
        assert!(scheduler.next().await.is_none());
    }

    #[tokio::test]
    async fn scheduler_dedupe_should_allow_rescheduling_emitted_item() {
        pause();
        let (mut schedule_tx, schedule_rx) = mpsc::unbounded();
        let mut scheduler = scheduler(schedule_rx);
        schedule_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now() + Duration::from_secs(1),
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        // The same message can be scheduled again after it has been emitted.
        schedule_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now() + Duration::from_secs(1),
            })
            .await
            .unwrap();
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(2)).await;
        scheduler.next().now_or_never().unwrap().unwrap();
        assert!(poll!(scheduler.next()).is_pending());
    }

    #[tokio::test]
    async fn scheduler_should_overwrite_message_with_soonest_version() {
        pause();

        let now = Instant::now();
        let scheduler = scheduler(
            stream::iter([
                ScheduleRequest {
                    message: SingletonMessage(1),
                    run_at: now + Duration::from_secs(2),
                },
                ScheduleRequest {
                    message: SingletonMessage(2),
                    run_at: now + Duration::from_secs(1),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        );
        // The second (sooner) request wins, so its payload is the one emitted.
        assert_eq!(scheduler.map(|msg| msg.0).collect::<Vec<_>>().await, vec![2]);
    }

    #[tokio::test]
    async fn scheduler_should_not_overwrite_message_with_later_version() {
        pause();

        let now = Instant::now();
        let scheduler = scheduler(
            stream::iter([
                ScheduleRequest {
                    message: SingletonMessage(1),
                    run_at: now + Duration::from_secs(1),
                },
                ScheduleRequest {
                    message: SingletonMessage(2),
                    run_at: now + Duration::from_secs(2),
                },
            ])
            .on_complete(sleep(Duration::from_secs(5))),
        );
        // The second (later) request is ignored, so the first payload is emitted.
        assert_eq!(scheduler.map(|msg| msg.0).collect::<Vec<_>>().await, vec![1]);
    }

    #[tokio::test]
    async fn scheduler_should_add_debounce_to_a_request() {
        pause();

        let now = Instant::now();
        let (mut sched_tx, sched_rx) = mpsc::unbounded::<ScheduleRequest<SingletonMessage>>();
        let mut scheduler = debounced_scheduler(sched_rx, Duration::from_secs(2));

        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(1),
                run_at: now,
            })
            .await
            .unwrap();
        // t = 1s: still inside the 2s debounce, so nothing is emitted.
        advance(Duration::from_secs(1)).await;
        assert!(poll!(scheduler.next()).is_pending());
        // t = 4s: the debounce has elapsed.
        advance(Duration::from_secs(3)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().0, 1);
    }

    #[tokio::test]
    async fn scheduler_should_dedup_message_within_debounce_period() {
        pause();

        let mut now = Instant::now();
        let (mut sched_tx, sched_rx) = mpsc::unbounded::<ScheduleRequest<SingletonMessage>>();
        let mut scheduler = debounced_scheduler(sched_rx, Duration::from_secs(3));

        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(1),
                run_at: now,
            })
            .await
            .unwrap();
        // The first request is now due at t = 3s.
        assert!(poll!(scheduler.next()).is_pending());
        advance(Duration::from_secs(1)).await;

        // t = 1s: a second request inside the debounce window pushes the deadline to t = 4s
        // and replaces the stored message.
        now = Instant::now();
        sched_tx
            .send(ScheduleRequest {
                message: SingletonMessage(2),
                run_at: now,
            })
            .await
            .unwrap();
        // t = 3.5s: past the first deadline but before the restarted one.
        advance(Duration::from_millis(2500)).await;
        assert!(poll!(scheduler.next()).is_pending());

        // t = 6.5s: the restarted deadline has passed; only the newer version is emitted.
        advance(Duration::from_secs(3)).await;
        assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().0, 2);
        assert!(poll!(scheduler.next()).is_pending());
    }
}