use cooked_waker::IntoWaker;
use cooked_waker::ViaRawPointer;
use cooked_waker::Wake;
use cooked_waker::WakeRef;
use futures::Future;
use std::cell::Cell;
use std::cell::Ref;
use std::cell::RefCell;
use std::cell::UnsafeCell;
use std::collections::btree_set;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::mem::MaybeUninit;
use std::num::NonZeroU64;
use std::pin::Pin;
use std::task::ready;
use std::task::Context;
use std::task::Poll;
use std::task::Waker;
use std::time::Duration;
use tokio::time::Instant;
use tokio::time::Sleep;
/// Identifier of a queued timer. `WebTimers` hands ids out sequentially,
/// starting at 1.
pub(crate) type WebTimerId = u64;
/// Whether a timer fires once or repeats.
///
/// NOTE: the derived `Ord` takes part in `TimerKey`'s ordering, so variant
/// declaration order is significant (`Repeat` sorts before `Once`).
#[derive(PartialEq, Eq, PartialOrd, Ord)]
enum TimerType {
/// Repeating timer; the payload is the interval in milliseconds (at least 1).
Repeat(NonZeroU64),
/// One-shot timer.
Once,
}
/// Ordering key of a scheduled timer: `(deadline, id, kind, is_system_timer)`.
///
/// The derived `Ord` compares fields in declaration order, so a
/// `BTreeSet<TimerKey>` yields timers by deadline, breaking ties by id.
/// Keys are not removed when a timer is cancelled: a key whose id is absent
/// from the data map is a tombstone that is skipped or pruned later.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TimerKey(Instant, u64, TimerType, bool);
/// Per-timer payload stored in `WebTimers::data_map`, keyed by timer id.
/// Presence in that map is what makes a timer "live".
struct TimerData<T> {
/// Caller-supplied value handed back when the timer fires or is cancelled.
data: T,
/// Whether this timer is currently counted in `unrefd_count`.
unrefd: bool,
/// Whether this timer holds a reference on the high-resolution timer lock.
/// Tracked only on Windows and in tests; elsewhere it is a zero-sized `()`.
#[cfg(any(windows, test))]
high_res: bool,
#[cfg(not(any(windows, test)))]
high_res: (),
}
/// A queue of one-shot and repeating timers driven by a single re-armable
/// tokio `Sleep`.
///
/// State lives in `Cell`/`RefCell`, so the type is not `Sync` and is meant
/// to be used from a single thread.
pub(crate) struct WebTimers<T> {
/// Last id handed out; the next timer gets `next_id + 1`.
next_id: Cell<WebTimerId>,
/// Deadline-ordered keys. May contain tombstones for cancelled timers.
timers: RefCell<BTreeSet<TimerKey>>,
/// Live timers by id; the source of truth for whether a timer exists.
data_map: RefCell<BTreeMap<WebTimerId, TimerData<T>>>,
/// Number of live timers currently marked unref'd.
unrefd_count: Cell<usize>,
/// Boxed so its address is stable: its internal waker keeps a raw pointer
/// back to it.
sleep: Box<MutableSleep>,
/// Reference count of live timers that requested high-resolution timing.
high_res_timer_lock: HighResTimerLock,
}
impl<T> Default for WebTimers<T> {
fn default() -> Self {
Self {
next_id: Default::default(),
timers: Default::default(),
data_map: Default::default(),
unrefd_count: Default::default(),
sleep: MutableSleep::new(),
high_res_timer_lock: Default::default(),
}
}
}
/// Shared borrows of a `WebTimers`' data map and key set, held for as long as
/// this value lives. Iterate it by reference (`for item in &view`).
pub(crate) struct WebTimersIterator<'a, T> {
data: Ref<'a, BTreeMap<WebTimerId, TimerData<T>>>,
timers: Ref<'a, BTreeSet<TimerKey>>,
}
impl<'a, T> IntoIterator for &'a WebTimersIterator<'a, T> {
  type IntoIter = WebTimersIteratorImpl<'a, T>;
  type Item = (u64, bool, bool);

  /// Walks the borrowed key set in order, consulting the borrowed data map
  /// to skip cancelled entries.
  fn into_iter(self) -> Self::IntoIter {
    let timers = self.timers.iter();
    let data = &*self.data;
    WebTimersIteratorImpl { data, timers }
  }
}
/// Borrowing iterator produced from `&WebTimersIterator`: walks the ordered
/// keys and uses `data` to tell live timers from tombstones.
pub(crate) struct WebTimersIteratorImpl<'a, T> {
data: &'a BTreeMap<WebTimerId, TimerData<T>>,
timers: btree_set::Iter<'a, TimerKey>,
}
impl<'a, T> Iterator for WebTimersIteratorImpl<'a, T> {
  type Item = (u64, bool, bool);

  /// Yields `(id, is_repeating, is_system_timer)` for the next live timer,
  /// skipping keys whose data entry has been removed (tombstones).
  fn next(&mut self) -> Option<Self::Item> {
    let live = self.data;
    self
      .timers
      .find(|key| live.contains_key(&key.1))
      .map(|key| (key.1, !matches!(key.2, TimerType::Once), key.3))
  }
}
/// A re-armable `Sleep` whose firing is observed through a flag rather than
/// by polling the `Sleep` directly.
///
/// `change` installs and polls a new `Sleep` with `internal_waker`. When the
/// sleep fires, the internal waker sets `ready` and wakes `external_waker`.
struct MutableSleep {
/// The current sleep, or `None` when disarmed. Pinned in place inside the box.
sleep: UnsafeCell<Option<Sleep>>,
/// Set when the sleep fires; consumed by `poll_ready`.
ready: Cell<bool>,
/// Waker registered through `poll_ready`; woken when the sleep fires.
external_waker: UnsafeCell<Option<Waker>>,
/// Waker handed to the inner `Sleep`; points back at this `MutableSleep`.
internal_waker: Waker,
}
// The methods take `self: &Box<Self>` on purpose: `internal_waker` holds a
// raw pointer to the boxed allocation, so they must only be called on the
// heap-allocated instance created by `new`.
#[allow(clippy::borrowed_box)]
impl MutableSleep {
  /// Allocates a disarmed `MutableSleep` on the heap.
  ///
  /// The value is built in place so that `internal_waker` can capture the
  /// final address of the allocation. The returned box must therefore never
  /// be moved out of.
  fn new() -> Box<Self> {
    let mut new = Box::new(MaybeUninit::<Self>::uninit());
    new.write(MutableSleep {
      sleep: Default::default(),
      ready: Cell::default(),
      external_waker: UnsafeCell::default(),
      internal_waker: MutableSleepWaker {
        inner: new.as_ptr(),
      }
      .into_waker(),
    });
    // SAFETY: every field was initialized by `write` above, and
    // `MaybeUninit<Self>` has the same layout as `Self`.
    unsafe { std::mem::transmute(new) }
  }

  /// Returns `Ready` once if the sleep has fired since the last `Ready`,
  /// otherwise `Pending`.
  ///
  /// The waker from `cx` is recorded on every call, not only when returning
  /// `Pending`. Callers (such as `WebTimers::poll_timers`) may return
  /// `Pending` after observing `Ready` here, and the task must still be woken
  /// by the next firing.
  fn poll_ready(self: &Box<Self>, cx: &mut Context) -> Poll<()> {
    // SAFETY: `external_waker` is only accessed through short-lived
    // references in this impl and in `MutableSleepWaker::wake_by_ref`, and
    // none of them overlap this one (single-threaded use is assumed).
    let external = unsafe { &mut *self.external_waker.get() };
    let waker = cx.waker();
    if let Some(external) = external {
      if !external.will_wake(waker) {
        external.clone_from(waker);
      }
    } else {
      *external = Some(waker.clone());
    }
    if self.ready.take() {
      Poll::Ready(())
    } else {
      Poll::Pending
    }
  }

  /// Disarms the sleep and discards any readiness it had already signalled.
  ///
  /// Resetting `ready` matters: a wake-up that arrived before the sleep was
  /// cleared would otherwise make the next `poll_ready` report `Ready` even
  /// though nothing is scheduled.
  fn clear(self: &Box<Self>) {
    // SAFETY: no other reference into `self.sleep` is live during this call;
    // the old `Sleep` is dropped in place, never moved.
    unsafe {
      *self.sleep.get() = None;
    }
    self.ready.set(false);
  }

  /// Re-arms the sleep for `instant`.
  ///
  /// The new `Sleep` is polled immediately with `internal_waker` so that it
  /// registers for wake-up. If it is already elapsed, readiness is signalled
  /// right away.
  fn change(self: &Box<Self>, instant: Instant) {
    // SAFETY: `self` lives in a `Box` that is never moved out of, so the
    // `Sleep` stored in `self.sleep` keeps a stable address until it is
    // replaced or cleared (both drop it in place). No other reference into
    // `self.sleep` is live during this call.
    let pin = unsafe {
      *self.sleep.get() = Some(tokio::time::sleep_until(instant));
      Pin::new_unchecked(
        self
          .sleep
          .get()
          .as_mut()
          .unwrap_unchecked()
          .as_mut()
          .unwrap_unchecked(),
      )
    };
    let waker = &self.internal_waker;
    if pin.poll(&mut Context::from_waker(waker)).is_ready() {
      self.ready.set(true);
      self.internal_waker.wake_by_ref();
    }
  }
}
/// Waker payload pointing back at the `MutableSleep` that created it.
///
/// Turned into a `Waker` through `cooked_waker` (`into_waker` /
/// `ViaRawPointer`). `repr(transparent)` keeps it exactly one raw pointer wide.
#[repr(transparent)]
#[derive(Clone)]
struct MutableSleepWaker {
inner: *const MutableSleep,
}
// SAFETY(review): wakers are expected to be `Send + Sync` (presumably
// required by `cooked_waker` to build a `Waker`), but `MutableSleep` mutates
// `Cell`/`UnsafeCell` state without synchronization. These impls are only
// sound if the waker is never woken concurrently with other access from a
// different thread. NOTE(review): confirm the runtime driving these timers
// is single-threaded.
unsafe impl Send for MutableSleepWaker {}
unsafe impl Sync for MutableSleepWaker {}
impl WakeRef for MutableSleepWaker {
  /// Marks the owning `MutableSleep` as ready, then forwards the wake-up to
  /// its external waker if one has been registered.
  fn wake_by_ref(&self) {
    // SAFETY: `inner` points at the boxed `MutableSleep` that created this
    // waker; that allocation is assumed to outlive every clone of the waker.
    let sleep = unsafe { &*self.inner };
    sleep.ready.set(true);
    // SAFETY: no mutable reference to `external_waker` is live while a
    // wake-up is delivered (single-threaded use is assumed).
    let external = unsafe { &*sleep.external_waker.get() };
    if let Some(waker) = external {
      waker.wake_by_ref();
    }
  }
}
impl Wake for MutableSleepWaker {
fn wake(self) {
self.wake_by_ref()
}
}
// Intentionally empty: the waker does not own the `MutableSleep` it points to.
// NOTE(review): presumably kept to satisfy a `cooked_waker` requirement or to
// prevent moving fields out of the waker; confirm before removing.
impl Drop for MutableSleepWaker {
fn drop(&mut self) {}
}
unsafe impl ViaRawPointer for MutableSleepWaker {
type Target = ();
fn into_raw(self) -> *mut () {
self.inner as _
}
unsafe fn from_raw(ptr: *mut ()) -> Self {
MutableSleepWaker { inner: ptr as _ }
}
}
impl<T: Clone> WebTimers<T> {
  /// Returns a view over the queued timers that can be iterated by
  /// reference.
  ///
  /// Iteration yields `(id, is_repeating, is_system_timer)` for each live
  /// timer, ordered by deadline and then id; cancelled entries are skipped.
  /// The view holds shared borrows of the internal `RefCell`s, so queueing,
  /// cancelling or polling timers while it is alive may panic.
  pub(crate) fn iter(&self) -> WebTimersIterator<T> {
    WebTimersIterator {
      data: self.data_map.borrow(),
      timers: self.timers.borrow(),
    }
  }

  /// Marks timer `id` as ref'd again. Unknown ids and timers that are
  /// already ref'd are ignored.
  pub fn ref_timer(&self, id: WebTimerId) {
    if let Some(TimerData { unrefd, .. }) =
      self.data_map.borrow_mut().get_mut(&id)
    {
      if std::mem::replace(unrefd, false) {
        self.unrefd_count.set(self.unrefd_count.get() - 1);
      }
    }
  }

  /// Marks timer `id` as unref'd so that it no longer counts towards
  /// [`Self::has_pending_timers`]. Unknown ids and timers that are already
  /// unref'd are ignored.
  pub fn unref_timer(&self, id: WebTimerId) {
    if let Some(TimerData { unrefd, .. }) =
      self.data_map.borrow_mut().get_mut(&id)
    {
      if !std::mem::replace(unrefd, true) {
        self.unrefd_count.set(self.unrefd_count.get() + 1);
      }
    }
  }

  /// Queues a one-shot timer that fires after `timeout_ms` milliseconds.
  pub fn queue_timer(&self, timeout_ms: u64, data: T) -> WebTimerId {
    self.queue_timer_internal(false, timeout_ms, data, false)
  }

  /// Queues a repeating timer that fires every `timeout_ms` milliseconds
  /// (a timeout of 0 repeats every 1 ms).
  pub fn queue_timer_repeat(&self, timeout_ms: u64, data: T) -> WebTimerId {
    self.queue_timer_internal(true, timeout_ms, data, false)
  }

  /// Queues a timer flagged as a system timer; the flag is reported by
  /// [`Self::iter`].
  pub fn queue_system_timer(
    &self,
    repeat: bool,
    timeout_ms: u64,
    data: T,
  ) -> WebTimerId {
    self.queue_timer_internal(repeat, timeout_ms, data, true)
  }

  /// Shared implementation of the `queue_*` methods; returns the new id.
  ///
  /// # Panics
  /// Panics if `Instant::now() + timeout_ms` overflows `Instant`.
  fn queue_timer_internal(
    &self,
    repeat: bool,
    timeout_ms: u64,
    data: T,
    is_system_timer: bool,
  ) -> WebTimerId {
    // Short timeouts take a reference on the high-resolution lock. On
    // builds without high-res support this evaluates to `()`.
    #[allow(clippy::let_unit_value)]
    let high_res = self.high_res_timer_lock.maybe_lock(timeout_ms);
    let id = self.next_id.get() + 1;
    self.next_id.set(id);
    let mut timers = self.timers.borrow_mut();
    let deadline = Instant::now()
      .checked_add(Duration::from_millis(timeout_ms))
      .unwrap();
    // Re-arm only if this timer becomes the earliest key (or the queue was
    // empty). If the current first key is a tombstone, the sleep wakes early
    // and `poll_timers` re-arms it for the first live deadline.
    if let Some(TimerKey(k, ..)) = timers.first() {
      if &deadline < k {
        self.sleep.change(deadline);
      }
    } else {
      self.sleep.change(deadline);
    }
    let timer_type = if repeat {
      TimerType::Repeat(
        NonZeroU64::new(timeout_ms).unwrap_or(NonZeroU64::new(1).unwrap()),
      )
    } else {
      TimerType::Once
    };
    timers.insert(TimerKey(deadline, id, timer_type, is_system_timer));
    let mut data_map = self.data_map.borrow_mut();
    data_map.insert(
      id,
      TimerData {
        data,
        unrefd: false,
        high_res,
      },
    );
    id
  }

  /// Cancels `timer`, returning its data, or `None` if it is not live
  /// (never queued, already cancelled, or a one-shot timer that has fired).
  ///
  /// Only the data entry is removed; the key stays in the ordered set as a
  /// tombstone until it is pruned. Cancelling the last live timer resets all
  /// bookkeeping and disarms the sleep.
  pub fn cancel_timer(&self, timer: u64) -> Option<T> {
    let mut data_map = self.data_map.borrow_mut();
    if let Some(TimerData {
      data,
      unrefd,
      high_res,
    }) = data_map.remove(&timer)
    {
      if data_map.is_empty() {
        debug_assert_eq!(self.unrefd_count.get(), if unrefd { 1 } else { 0 });
        #[cfg(any(windows, test))]
        debug_assert_eq!(self.high_res_timer_lock.is_locked(), high_res);
        self.high_res_timer_lock.clear();
        self.unrefd_count.set(0);
        self.timers.borrow_mut().clear();
        self.sleep.clear();
      } else {
        self.high_res_timer_lock.maybe_unlock(high_res);
        if unrefd {
          self.unrefd_count.set(self.unrefd_count.get() - 1);
        }
      }
      Some(data)
    } else {
      None
    }
  }

  /// Returns `(id, data)` for every live timer whose deadline has passed,
  /// once the internal sleep has fired.
  ///
  /// One-shot timers are removed; repeating timers are rescheduled at
  /// `now + interval` and their data is cloned. Returns `Pending` if the
  /// sleep has not fired or if nothing live has expired yet.
  pub fn poll_timers(&self, cx: &mut Context) -> Poll<Vec<(u64, T)>> {
    ready!(self.sleep.poll_ready(cx));
    let now = Instant::now();
    let mut timers = self.timers.borrow_mut();
    let mut data = self.data_map.borrow_mut();

    // The sleep can report readiness after the last live timer was
    // cancelled (its wake-up raced the cancellation). Nothing can fire, so
    // drop any leftover state instead of tripping the assertion below.
    if data.is_empty() {
      timers.clear();
      self.sleep.clear();
      return Poll::Pending;
    }

    let mut output = vec![];
    // `split_off` returns the keys >= the probe. After the swap, `split`
    // holds the expired keys (deadline < now) and `timers` keeps the rest.
    // Ids start at 1, so the probe id 0 never equals a real key.
    let mut split = timers.split_off(&TimerKey(now, 0, TimerType::Once, false));
    std::mem::swap(&mut split, &mut timers);
    for TimerKey(_, id, timer_type, is_system_timer) in split {
      if let TimerType::Repeat(interval) = timer_type {
        if let Some(TimerData { data, .. }) = data.get(&id) {
          output.push((id, data.clone()));
          timers.insert(TimerKey(
            now
              .checked_add(Duration::from_millis(interval.into()))
              .unwrap(),
            id,
            timer_type,
            is_system_timer,
          ));
        }
      } else if let Some(TimerData {
        data,
        unrefd,
        high_res,
      }) = data.remove(&id)
      {
        self.high_res_timer_lock.maybe_unlock(high_res);
        if unrefd {
          self.unrefd_count.set(self.unrefd_count.get() - 1);
        }
        output.push((id, data));
      }
    }
    if output.is_empty() {
      // Only tombstones had expired, or no live deadline has passed yet.
      // Drop leading tombstones and re-arm for the earliest live deadline.
      debug_assert!(!data.is_empty());
      while let Some(TimerKey(_, id, ..)) = timers.first() {
        if data.contains_key(id) {
          break;
        } else {
          timers.pop_first();
        }
      }
      if let Some(TimerKey(k, ..)) = timers.first() {
        self.sleep.change(*k);
      }
      return Poll::Pending;
    }
    if data.is_empty() {
      // Everything has fired: drop leftover tombstones and disarm the sleep
      // so no stale, already-elapsed `Sleep` is kept around.
      timers.clear();
      self.sleep.clear();
    } else {
      // Compact once tombstones outnumber live timers (and exceed a small
      // floor), so the key set cannot grow without bound under churn.
      const COMPACTION_MINIMUM: usize = 16;
      let tombstone_count = timers.len() - data.len();
      if tombstone_count > data.len() && tombstone_count > COMPACTION_MINIMUM {
        timers.retain(|k| data.contains_key(&k.1));
      }
      if let Some(TimerKey(k, ..)) = timers.first() {
        self.sleep.change(*k);
      }
    }
    Poll::Ready(output)
  }

  /// Returns `true` if no timer is live.
  pub fn is_empty(&self) -> bool {
    self.data_map.borrow().is_empty()
  }

  /// Number of live timers, both ref'd and unref'd.
  pub fn len(&self) -> usize {
    self.data_map.borrow().len()
  }

  /// Number of live timers currently marked unref'd.
  pub fn unref_len(&self) -> usize {
    self.unrefd_count.get()
  }

  /// Test-only check of the internal bookkeeping invariants.
  #[cfg(test)]
  pub fn assert_consistent(&self) {
    if self.data_map.borrow().is_empty() {
      assert_eq!(self.timers.borrow().len(), 0);
      assert_eq!(self.unrefd_count.get(), 0);
      assert!(!self.high_res_timer_lock.is_locked());
    } else {
      assert!(self.unrefd_count.get() <= self.data_map.borrow().len());
      assert!(self.high_res_timer_lock.lock_count.get() <= self.len());
    }
  }

  /// Returns `true` if at least one live timer is ref'd.
  pub fn has_pending_timers(&self) -> bool {
    self.len() > self.unref_len()
  }
}
// Windows multimedia timer API, used by `HighResTimerLock` to request
// (`timeBeginPeriod`) and release (`timeEndPeriod`) a finer system timer
// resolution. Per the Win32 docs, `n` is a period in milliseconds, and every
// begin call must be matched by an end call with the same value.
#[cfg(windows)]
#[link(name = "winmm")]
extern "C" {
fn timeBeginPeriod(n: u32);
fn timeEndPeriod(n: u32);
}
/// Reference count of live timers that requested high-resolution timing.
///
/// On Windows, the first lock calls `timeBeginPeriod(1)` and releasing the
/// last one calls `timeEndPeriod(1)`. In tests the count is tracked without
/// calling into Windows. On other platforms the type is empty and every
/// operation is a no-op.
#[derive(Default)]
struct HighResTimerLock {
#[cfg(any(windows, test))]
lock_count: Cell<usize>,
}
impl HighResTimerLock {
  /// Timeouts at or below this many milliseconds request high-resolution
  /// timing.
  #[cfg(any(windows, test))]
  const LOW_RES_TIMER_RESOLUTION: u64 = 100;

  /// Releases one reference taken by `maybe_lock` when `high_res` is `true`.
  /// On Windows, releasing the last reference calls `timeEndPeriod(1)`.
  #[cfg(any(windows, test))]
  #[inline(always)]
  fn maybe_unlock(&self, high_res: bool) {
    if !high_res {
      return;
    }
    let held = self.lock_count.get();
    debug_assert!(held > 0);
    let remaining = held - 1;
    self.lock_count.set(remaining);
    #[cfg(windows)]
    if remaining == 0 {
      // SAFETY: balances the `timeBeginPeriod(1)` issued when the count
      // went from 0 to 1.
      unsafe {
        timeEndPeriod(1);
      }
    }
  }

  /// No-op on platforms without high-resolution timer support.
  #[cfg(not(any(windows, test)))]
  #[inline(always)]
  fn maybe_unlock(&self, _high_res: ()) {}

  /// Takes a reference if `timeout_ms` is short enough and reports whether
  /// it did. On Windows, the first reference calls `timeBeginPeriod(1)`.
  #[cfg(any(windows, test))]
  #[inline(always)]
  fn maybe_lock(&self, timeout_ms: u64) -> bool {
    let wants_high_res = timeout_ms <= Self::LOW_RES_TIMER_RESOLUTION;
    if wants_high_res {
      let held = self.lock_count.get();
      #[cfg(windows)]
      if held == 0 {
        // SAFETY: plain FFI call; balanced by `timeEndPeriod(1)` when the
        // count returns to zero.
        unsafe {
          timeBeginPeriod(1);
        }
      }
      self.lock_count.set(held + 1);
    }
    wants_high_res
  }

  /// No-op on platforms without high-resolution timer support.
  #[cfg(not(any(windows, test)))]
  #[inline(always)]
  fn maybe_lock(&self, _timeout_ms: u64) {}

  /// Drops every reference at once. On Windows, this calls
  /// `timeEndPeriod(1)` if the lock was held.
  #[cfg(any(windows, test))]
  #[inline(always)]
  fn clear(&self) {
    #[cfg(windows)]
    if self.is_locked() {
      // SAFETY: balances the single outstanding `timeBeginPeriod(1)`.
      unsafe {
        timeEndPeriod(1);
      }
    }
    self.lock_count.set(0);
  }

  /// No-op on platforms without high-resolution timer support.
  #[cfg(not(any(windows, test)))]
  #[inline(always)]
  fn clear(&self) {}

  /// Returns `true` while at least one reference is held.
  #[cfg(any(windows, test))]
  fn is_locked(&self) -> bool {
    self.lock_count.get() != 0
  }
}
// Tests drive `WebTimers` on a current-thread tokio runtime and check the
// internal invariants (`assert_consistent`) around every poll.
#[cfg(test)]
mod tests {
use super::*;
use futures::future::poll_fn;
use rstest::rstest;
// Scaled down under Miri, which runs far slower than native execution.
const TEN_THOUSAND: u64 = if cfg!(miri) { 100 } else { 10_000 };
// Runs `f` to completion on a fresh current-thread runtime with the time
// driver enabled (needed by tokio `Sleep`).
fn async_test<F: Future<Output = T>, T>(f: F) -> T {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_time()
.build()
.unwrap();
runtime.block_on(f)
}
// Polls until no live timer remains, collecting every fired `(id, data)`.
// Asserts that the number fired equals the number live at the start.
async fn poll_all(timers: &WebTimers<()>) -> Vec<(u64, ())> {
timers.assert_consistent();
let len = timers.len();
let mut v = vec![];
while !timers.is_empty() {
let mut batch = poll_fn(|cx| {
timers.assert_consistent();
timers.poll_timers(cx)
})
.await;
v.append(&mut batch);
#[allow(clippy::print_stderr)]
{
eprintln!(
"{} ({} {})",
v.len(),
timers.len(),
timers.data_map.borrow().len(),
);
}
timers.assert_consistent();
}
assert_eq!(v.len(), len);
v
}
// A single 1 ms timer fires exactly once.
#[test]
fn test_timer() {
async_test(async {
let timers = WebTimers::<()>::default();
let _a = timers.queue_timer(1, ());
let v = poll_all(&timers).await;
assert_eq!(v.len(), 1);
});
}
// A short timer takes the high-res lock, and firing it releases the lock.
#[test]
fn test_high_res_lock() {
async_test(async {
let timers = WebTimers::<()>::default();
assert!(!timers.high_res_timer_lock.is_locked());
let _a = timers.queue_timer(1, ());
assert!(timers.high_res_timer_lock.is_locked());
let v = poll_all(&timers).await;
assert_eq!(v.len(), 1);
assert!(!timers.high_res_timer_lock.is_locked());
});
}
// Cancelling any one of four staggered timers leaves the other three to fire.
#[rstest]
#[test]
fn test_timer_cancel_1(#[values(0, 1, 2, 3)] which: u64) {
async_test(async {
let timers = WebTimers::<()>::default();
for i in 0..4 {
let id = timers.queue_timer(i * 25, ());
if i == which {
assert!(timers.cancel_timer(id).is_some());
}
}
assert_eq!(timers.len(), 3);
let v = poll_all(&timers).await;
assert_eq!(v.len(), 3);
})
}
// Cancelling two adjacent timers out of four leaves the other two to fire.
#[rstest]
#[test]
fn test_timer_cancel_2(#[values(0, 1, 2)] which: u64) {
async_test(async {
let timers = WebTimers::<()>::default();
for i in 0..4 {
let id = timers.queue_timer(i * 25, ());
if i == which || i == which + 1 {
assert!(timers.cancel_timer(id).is_some());
}
}
assert_eq!(timers.len(), 2);
let v = poll_all(&timers).await;
assert_eq!(v.len(), 2);
})
}
// Ten timers with repeating 0/10/20 ms timeouts all fire.
#[test]
fn test_timers_10_random() {
async_test(async {
let timers = WebTimers::<()>::default();
for i in 0..10 {
timers.queue_timer((i % 3) * 10, ());
}
let v = poll_all(&timers).await;
assert_eq!(v.len(), 10);
})
}
// Each timer is cancelled right after being queued, so nothing fires.
#[test]
fn test_timers_10_random_cancel() {
async_test(async {
let timers = WebTimers::<()>::default();
for i in 0..10 {
let id = timers.queue_timer((i % 3) * 10, ());
timers.cancel_timer(id);
}
let v = poll_all(&timers).await;
assert_eq!(v.len(), 0);
});
}
// Cancelling queued timers in either order keeps the state consistent.
#[rstest]
#[test]
fn test_timers_10_random_cancel_after(#[values(true, false)] reverse: bool) {
async_test(async {
let timers = WebTimers::<()>::default();
let mut ids = vec![];
for i in 0..2 {
ids.push(timers.queue_timer((i % 3) * 10, ()));
}
if reverse {
ids.reverse();
}
for id in ids {
timers.cancel_timer(id);
timers.assert_consistent();
}
let v = poll_all(&timers).await;
assert_eq!(v.len(), 0);
});
}
// Ten timers sharing the same 1 ms timeout all fire.
#[test]
fn test_timers_10() {
async_test(async {
let timers = WebTimers::<()>::default();
for _i in 0..10 {
timers.queue_timer(1, ());
}
let v = poll_all(&timers).await;
assert_eq!(v.len(), 10);
});
}
// Many timers with 0..10 ms timeouts all fire.
#[test]
fn test_timers_10_000_random() {
async_test(async {
let timers = WebTimers::<()>::default();
for i in 0..TEN_THOUSAND {
timers.queue_timer(i % 10, ());
}
let v = poll_all(&timers).await;
assert_eq!(v.len(), TEN_THOUSAND as usize);
});
}
// Cancelling a large batch of early timers leaves many leading tombstones;
// only the ten later timers may fire.
#[test]
fn test_timers_cancel_first() {
async_test(async {
let timers = WebTimers::<()>::default();
let mut ids = vec![];
for _ in 0..TEN_THOUSAND {
ids.push(timers.queue_timer(1, ()));
}
for i in 0..10 {
timers.queue_timer(i * 25, ());
}
for id in ids {
timers.cancel_timer(id);
}
let v = poll_all(&timers).await;
assert_eq!(v.len(), 10);
});
}
// Cancels a seeded random ~90% subset; the remainder must fire.
#[test]
fn test_timers_10_000_cancel_most() {
async_test(async {
let timers = WebTimers::<()>::default();
let mut ids = vec![];
for i in 0..TEN_THOUSAND {
ids.push(timers.queue_timer(i % 100, ()));
}
fastrand::seed(42);
ids.retain(|_| fastrand::u8(0..10) > 0);
for id in ids.iter() {
timers.cancel_timer(*id);
}
let v = poll_all(&timers).await;
assert_eq!(v.len(), TEN_THOUSAND as usize - ids.len());
});
}
// Seeded random mix of queue/cancel/ref/unref operations. The exact order of
// `fastrand` calls defines the scenario for each seed. `ref_count` is only
// reported, not asserted.
#[rstest]
#[test]
fn test_chaos(#[values(42, 99, 1000)] seed: u64) {
async_test(async {
let timers = WebTimers::<()>::default();
fastrand::seed(seed);
let mut count = 0;
let mut ref_count = 0;
for _ in 0..TEN_THOUSAND {
let mut cancelled = false;
let mut unrefed = false;
let id = timers.queue_timer(fastrand::u64(0..10), ());
for _ in 0..fastrand::u64(0..10) {
if fastrand::u8(0..10) == 0 {
timers.cancel_timer(id);
cancelled = true;
}
if fastrand::u8(0..10) == 0 {
timers.ref_timer(id);
unrefed = false;
}
if fastrand::u8(0..10) == 0 {
timers.unref_timer(id);
unrefed = true;
}
}
if !cancelled {
count += 1;
}
if !unrefed {
ref_count += 1;
}
timers.assert_consistent();
}
#[allow(clippy::print_stderr)]
{
eprintln!("count={count} ref_count={ref_count}");
}
let v = poll_all(&timers).await;
assert_eq!(v.len(), count);
assert!(timers.is_empty());
assert!(!timers.has_pending_timers());
});
}
}