use futures::{
stream::{Fuse, FusedStream},
Stream, StreamExt,
};
use pin_project::pin_project;
use snafu::{Backtrace, ResultExt, Snafu};
use std::{
collections::{hash_map::Entry, HashMap, HashSet},
hash::Hash,
pin::Pin,
task::{Context, Poll},
};
use tokio::time::{
self,
delay_queue::{self, DelayQueue},
Instant,
};
/// Errors that can be yielded by the scheduler streams defined in this module.
#[derive(Debug, Snafu)]
pub enum Error {
/// The timer backing the delay queue reported a failure.
///
/// The original [`time::Error`] is preserved as `source`, and a backtrace is
/// captured alongside it.
#[snafu(display("timer failure: {}", source))]
TimerError {
source: time::Error,
backtrace: Backtrace,
},
}
/// Shorthand for [`std::result::Result`] whose error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// A request asking the scheduler to emit `message` once the instant `run_at` is reached.
#[derive(Debug)]
pub struct ScheduleRequest<T> {
/// The value the scheduler should emit.
pub message: T,
/// The instant at which `message` becomes due.
pub run_at: Instant,
}
/// Bookkeeping stored in `Scheduler::scheduled` for a message that currently sits in the delay queue.
struct ScheduledEntry {
/// The instant at which the message is currently scheduled to expire.
run_at: Instant,
/// Key of the message's slot in the delay queue; used to reset its deadline.
queue_key: delay_queue::Key,
}
/// Stream that consumes [`ScheduleRequest`]s and emits each requested message
/// once its `run_at` instant has passed, keeping at most one scheduled entry
/// per distinct message.
#[pin_project(project = SchedulerProj)]
pub struct Scheduler<T, R> {
/// Timer queue holding the messages that are still waiting for their deadline.
queue: DelayQueue<T>,
/// Metadata for every message currently in `queue`, keyed by the message itself.
scheduled: HashMap<T, ScheduledEntry>,
/// Messages that already expired from `queue` but were rejected by the
/// `hold_unless` predicate in effect at the time; they stay here until a
/// predicate accepts them.
pending: HashSet<T>,
/// Incoming schedule requests, wrapped in `Fuse` so termination can be
/// queried via `is_terminated`.
#[pin]
requests: Fuse<R>,
}
impl<T, R: Stream> Scheduler<T, R> {
    /// Builds an empty scheduler that reads its schedule requests from `requests`.
    ///
    /// The request stream is fused up front; the delay queue, the metadata map
    /// and the set of held-back messages all start out empty.
    fn new(requests: R) -> Self {
        let requests = requests.fuse();
        Self {
            requests,
            pending: HashSet::new(),
            scheduled: HashMap::new(),
            queue: DelayQueue::new(),
        }
    }
}
impl<'a, T: Hash + Eq + Clone, R> SchedulerProj<'a, T, R> {
fn schedule_message(&mut self, request: ScheduleRequest<T>) {
if self.pending.contains(&request.message) {
return;
}
match self.scheduled.entry(request.message) {
Entry::Occupied(mut old_entry) if old_entry.get().run_at >= request.run_at => {
let entry = old_entry.get_mut();
self.queue.reset_at(&entry.queue_key, request.run_at);
entry.run_at = request.run_at;
}
Entry::Occupied(_old_entry) => {
}
Entry::Vacant(entry) => {
let message = entry.key().clone();
entry.insert(ScheduledEntry {
run_at: request.run_at,
queue_key: self.queue.insert_at(message, request.run_at),
});
}
}
}
fn poll_pop_queue_message(
&mut self,
cx: &mut Context<'_>,
can_take_message: impl Fn(&T) -> bool,
) -> Poll<Option<Result<T, time::Error>>> {
if let Some(msg) = self.pending.iter().find(|msg| can_take_message(*msg)).cloned() {
return Poll::Ready(Some(Ok(self.pending.take(&msg).unwrap())));
}
loop {
match self.queue.poll_expired(cx) {
Poll::Ready(Some(Ok(msg))) => {
let msg = msg.into_inner();
self.scheduled.remove(&msg).expect(
"Expired message was popped from the Scheduler queue, but was not in the metadata map",
);
if can_take_message(&msg) {
break Poll::Ready(Some(Ok(msg)));
} else {
self.pending.insert(msg);
}
}
Poll::Ready(Some(Err(err))) => break Poll::Ready(Some(Err(err))),
Poll::Ready(None) => {
break if self.pending.is_empty() {
Poll::Ready(None)
} else {
Poll::Pending
};
}
Poll::Pending => break Poll::Pending,
}
}
}
}
/// Stream view over a borrowed [`Scheduler`] that only emits messages accepted
/// by `can_take_message`; expired messages that the predicate rejects are kept
/// in the scheduler instead of being emitted.
pub struct HoldUnless<'a, T, R, C> {
/// The scheduler being driven by this stream.
scheduler: Pin<&'a mut Scheduler<T, R>>,
/// Predicate deciding whether a message may be emitted right now.
can_take_message: C,
}
impl<'a, T, R, C> Stream for HoldUnless<'a, T, R, C>
where
T: Eq + Hash + Clone,
R: Stream<Item = ScheduleRequest<T>>,
C: Fn(&T) -> bool + Unpin,
{
type Item = Result<T>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let can_take_message = &this.can_take_message;
let mut scheduler = this.scheduler.as_mut().project();
while let Poll::Ready(Some(request)) = scheduler.requests.as_mut().poll_next(cx) {
scheduler.schedule_message(request);
}
match scheduler.poll_pop_queue_message(cx, &can_take_message) {
Poll::Ready(Some(expired)) => Poll::Ready(Some(expired.context(TimerError))),
Poll::Ready(None) => {
if scheduler.requests.is_terminated() {
Poll::Ready(None)
} else {
Poll::Pending
}
}
Poll::Pending => Poll::Pending,
}
}
}
impl<T, R> Scheduler<T, R>
where
    T: Eq + Hash + Clone,
    R: Stream<Item = ScheduleRequest<T>>,
{
    /// Borrows the scheduler as a stream that emits only messages for which
    /// `can_take_message` returns `true`.
    ///
    /// Expired messages that the predicate rejects are retained by the
    /// scheduler and become eligible again on later polls, including polls
    /// made through a different `hold_unless` predicate.
    pub fn hold_unless<C: Fn(&T) -> bool>(self: Pin<&mut Self>, can_take_message: C) -> HoldUnless<T, R, C> {
        let scheduler = self;
        HoldUnless {
            can_take_message,
            scheduler,
        }
    }

    /// Test helper: reports whether `msg` is currently held back in `pending`.
    #[cfg(test)]
    pub fn contains_pending(&self, msg: &T) -> bool {
        let held = &self.pending;
        held.contains(msg)
    }
}
impl<T, R> Stream for Scheduler<T, R>
where
    T: Eq + Hash + Clone,
    R: Stream<Item = ScheduleRequest<T>>,
{
    type Item = Result<T>;

    /// Polls the scheduler without holding anything back: equivalent to
    /// polling `hold_unless` with a predicate that accepts every message.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut accept_all = self.hold_unless(|_| true);
        Pin::new(&mut accept_all).poll_next(cx)
    }
}
/// Wraps a stream of [`ScheduleRequest`]s in a [`Scheduler`].
///
/// The returned stream emits each requested message once its `run_at` has
/// passed. While a message is still waiting, further requests for the same
/// message are merged into the existing entry rather than queued separately.
pub fn scheduler<T: Eq + Hash + Clone, S: Stream<Item = ScheduleRequest<T>>>(
    requests: S,
) -> Scheduler<T, S> {
    Scheduler::new(requests)
}
// Unit tests for the scheduler. Most of them pause the tokio clock and move it
// forward explicitly with `advance`, so the order of the poll / advance calls
// in each test is significant.
#[cfg(test)]
mod tests {
use super::{scheduler, ScheduleRequest};
use futures::{channel::mpsc, poll, stream, FutureExt, SinkExt, StreamExt};
use std::task::Poll;
use tokio::time::{advance, pause, Duration, Instant};
/// Extracts the value from a `Poll::Ready`, panicking if the poll is still pending.
fn unwrap_poll<T>(poll: Poll<T>) -> T {
if let Poll::Ready(x) = poll {
x
} else {
panic!("Tried to unwrap a pending poll!")
}
}
// A message rejected by the `hold_unless` predicate must be parked as pending
// (the stream reports `Pending`), then emitted and removed from the pending set
// once a predicate accepts it; afterwards the stream ends.
#[tokio::test]
async fn scheduler_should_hold_and_release_items() {
pause();
let mut scheduler = Box::pin(scheduler(stream::iter(vec![ScheduleRequest {
message: 1_u8,
run_at: Instant::now(),
}])));
assert!(!scheduler.contains_pending(&1));
assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
assert!(scheduler.contains_pending(&1));
assert_eq!(
unwrap_poll(poll!(scheduler.as_mut().hold_unless(|_| true).next()))
.unwrap()
.unwrap(),
1_u8
);
assert!(!scheduler.contains_pending(&1));
assert!(scheduler.as_mut().hold_unless(|_| true).next().await.is_none());
}
// A new request for a message that is already held as pending must be ignored:
// the message is emitted exactly once and the stream then ends.
#[tokio::test]
async fn scheduler_should_not_reschedule_pending_items() {
pause();
let (mut tx, rx) = mpsc::unbounded::<ScheduleRequest<u8>>();
let mut scheduler = Box::pin(scheduler(rx));
tx.send(ScheduleRequest {
message: 1,
run_at: Instant::now(),
})
.await
.unwrap();
// Rejecting the message here moves it into the pending set.
assert!(poll!(scheduler.as_mut().hold_unless(|_| false).next()).is_pending());
tx.send(ScheduleRequest {
message: 1,
run_at: Instant::now(),
})
.await
.unwrap();
// Closing the sender lets the scheduler stream terminate once it is drained.
drop(tx);
assert_eq!(scheduler.next().await.unwrap().unwrap(), 1);
assert!(scheduler.next().await.is_none());
}
// A held-back message (1) must not prevent another due message (2) from being emitted.
#[tokio::test]
async fn scheduler_pending_message_should_not_block_head_of_line() {
let mut scheduler = Box::pin(scheduler(stream::iter(vec![
ScheduleRequest {
message: 1,
run_at: Instant::now(),
},
ScheduleRequest {
message: 2,
run_at: Instant::now(),
},
])));
assert_eq!(
scheduler
.as_mut()
.hold_unless(|x| *x != 1)
.next()
.await
.unwrap()
.unwrap(),
2
);
}
// Messages due at +1s and +3s must be emitted only after the clock has passed
// their respective deadlines, in order, and the stream must end afterwards.
#[tokio::test]
async fn scheduler_should_emit_items_as_requested() {
pause();
let mut scheduler = scheduler(stream::iter(vec![
ScheduleRequest {
message: 1_u8,
run_at: Instant::now() + Duration::from_secs(1),
},
ScheduleRequest {
message: 2,
run_at: Instant::now() + Duration::from_secs(3),
},
]));
assert!(poll!(scheduler.next()).is_pending());
advance(Duration::from_secs(2)).await;
assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().unwrap(), 1);
assert!(poll!(scheduler.next()).is_pending());
advance(Duration::from_secs(2)).await;
assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().unwrap(), 2);
assert!(scheduler.next().await.is_none());
}
// Two requests for the same message (+1s, then +3s): the earlier deadline wins
// and the message is emitted only once.
#[tokio::test]
async fn scheduler_dedupe_should_keep_earlier_item() {
pause();
let mut scheduler = scheduler(stream::iter(vec![
ScheduleRequest {
message: (),
run_at: Instant::now() + Duration::from_secs(1),
},
ScheduleRequest {
message: (),
run_at: Instant::now() + Duration::from_secs(3),
},
]));
assert!(poll!(scheduler.next()).is_pending());
advance(Duration::from_secs(2)).await;
assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().unwrap(), ());
assert!(scheduler.next().await.is_none());
}
// Two requests for the same message (+3s, then +1s): the later-arriving but
// earlier-due request replaces the deadline, so the message is ready after 2s.
#[tokio::test]
async fn scheduler_dedupe_should_replace_later_item() {
pause();
let mut scheduler = scheduler(stream::iter(vec![
ScheduleRequest {
message: (),
run_at: Instant::now() + Duration::from_secs(3),
},
ScheduleRequest {
message: (),
run_at: Instant::now() + Duration::from_secs(1),
},
]));
assert!(poll!(scheduler.next()).is_pending());
advance(Duration::from_secs(2)).await;
assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().unwrap(), ());
assert!(scheduler.next().await.is_none());
}
// Once a message has been emitted, a fresh request for the same message must be
// scheduled and emitted again rather than being treated as a duplicate.
#[tokio::test]
async fn scheduler_dedupe_should_allow_rescheduling_emitted_item() {
pause();
let (mut schedule_tx, schedule_rx) = mpsc::unbounded();
let mut scheduler = scheduler(schedule_rx);
schedule_tx
.send(ScheduleRequest {
message: (),
run_at: Instant::now() + Duration::from_secs(1),
})
.await
.unwrap();
assert!(poll!(scheduler.next()).is_pending());
advance(Duration::from_secs(2)).await;
assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().unwrap(), ());
// The sender is still open, so the stream stays pending instead of ending.
assert!(poll!(scheduler.next()).is_pending());
schedule_tx
.send(ScheduleRequest {
message: (),
run_at: Instant::now() + Duration::from_secs(1),
})
.await
.unwrap();
assert!(poll!(scheduler.next()).is_pending());
advance(Duration::from_secs(2)).await;
assert_eq!(scheduler.next().now_or_never().unwrap().unwrap().unwrap(), ());
assert!(poll!(scheduler.next()).is_pending());
}
}