1use std::future::Future;
8use std::time::Duration;
9
10use fedimint_logging::LOG_RUNTIME;
11use thiserror::Error;
12use tokio::time::Instant;
13use tracing::Instrument;
14
15#[derive(Debug, Error)]
16#[error("deadline has elapsed")]
17pub struct Elapsed;
18
19pub use self::r#impl::*;
20
21#[cfg(not(target_family = "wasm"))]
22mod r#impl {
23 pub use tokio::task::{JoinError, JoinHandle};
24
25 use super::{Duration, Elapsed, Future, Instant, Instrument, LOG_RUNTIME};
26
27 pub fn spawn<F, T>(name: &str, future: F) -> tokio::task::JoinHandle<T>
28 where
29 F: Future<Output = T> + 'static + Send,
30 T: Send + 'static,
31 {
32 let span = tracing::debug_span!(target: LOG_RUNTIME, parent: None, "spawn", task = name);
33 tokio::spawn(future.instrument(span))
35 }
36
37 pub fn block_in_place<F, R>(f: F) -> R
40 where
41 F: FnOnce() -> R,
42 {
43 tokio::task::block_in_place(f)
45 }
46
47 pub fn block_on<F: Future>(future: F) -> F::Output {
50 tokio::runtime::Handle::current().block_on(future)
52 }
53
54 pub async fn sleep(duration: Duration) {
55 tokio::time::sleep(duration).await;
57 }
58
59 pub async fn sleep_until(deadline: Instant) {
60 tokio::time::sleep_until(deadline).await;
61 }
62
63 pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed>
64 where
65 T: Future,
66 {
67 tokio::time::timeout(duration, future)
68 .await
69 .map_err(|_| Elapsed)
70 }
71}
72
73#[cfg(target_family = "wasm")]
74mod r#impl {
75
76 pub use std::convert::Infallible as JoinError;
77 use std::pin::Pin;
78 use std::task::{Context, Poll};
79
80 use async_lock::{RwLock, RwLockReadGuard, RwLockWriteGuard};
81 use futures_util::future::RemoteHandle;
82 use futures_util::FutureExt;
83
84 use super::*;
85
86 #[derive(Debug)]
87 pub struct JoinHandle<T> {
88 handle: Option<RemoteHandle<T>>,
89 }
90
91 impl<T> JoinHandle<T> {
92 pub fn abort(&mut self) {
93 drop(self.handle.take());
94 }
95 }
96
97 impl<T> Drop for JoinHandle<T> {
98 fn drop(&mut self) {
99 if let Some(h) = self.handle.take() {
101 h.forget();
102 }
103 }
104 }
105 impl<T: 'static> Future for JoinHandle<T> {
106 type Output = Result<T, JoinError>;
107
108 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
109 if let Some(handle) = self.handle.as_mut() {
110 Pin::new(handle).poll(cx).map(Ok)
111 } else {
112 Poll::Pending
113 }
114 }
115 }
116
117 pub fn spawn<F, T>(name: &str, future: F) -> JoinHandle<T>
118 where
119 F: Future<Output = T> + 'static,
120 {
121 let span = tracing::debug_span!(target: LOG_RUNTIME, "spawn", task = name);
122 let (fut, handle) = future.remote_handle();
123 wasm_bindgen_futures::spawn_local(fut);
124
125 JoinHandle {
126 handle: Some(handle),
127 }
128 }
129
130 pub(crate) fn spawn_local<F>(name: &str, future: F) -> JoinHandle<()>
131 where
132 F: Future<Output = ()> + 'static,
134 {
135 spawn(name, future)
136 }
137
138 pub async fn sleep(duration: Duration) {
139 gloo_timers::future::sleep(duration.min(Duration::from_millis(i32::MAX as _))).await
140 }
141
142 pub async fn sleep_until(deadline: Instant) {
143 sleep(deadline.saturating_duration_since(Instant::now())).await
146 }
147
148 pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed>
149 where
150 T: Future,
151 {
152 futures::pin_mut!(future);
153 futures::select_biased! {
154 value = future.fuse() => Ok(value),
155 _ = sleep(duration).fuse() => Err(Elapsed),
156 }
157 }
158}