fedimint_core/
runtime.rs

1//! Copyright 2021 The Matrix.org Foundation C.I.C.
2//! Abstraction over an executor so we can spawn tasks under WASM the same way
3//! we do usually.
4
5// Adapted from https://github.com/matrix-org/matrix-rust-sdk
6
7use std::future::Future;
8use std::time::Duration;
9
10use fedimint_logging::LOG_RUNTIME;
11use thiserror::Error;
12use tokio::time::Instant;
13use tracing::Instrument;
14
15#[derive(Debug, Error)]
16#[error("deadline has elapsed")]
17pub struct Elapsed;
18
19pub use self::r#impl::*;
20
21#[cfg(not(target_family = "wasm"))]
22mod r#impl {
23    pub use tokio::task::{JoinError, JoinHandle};
24
25    use super::{Duration, Elapsed, Future, Instant, Instrument, LOG_RUNTIME};
26
27    pub fn spawn<F, T>(name: &str, future: F) -> tokio::task::JoinHandle<T>
28    where
29        F: Future<Output = T> + 'static + Send,
30        T: Send + 'static,
31    {
32        let span = tracing::debug_span!(target: LOG_RUNTIME, parent: None, "spawn", task = name);
33        // nosemgrep: ban-tokio-spawn
34        tokio::spawn(future.instrument(span))
35    }
36
37    // note: this call does not exist on wasm and you need to handle it
38    // conditionally at the call site of packages that compile on wasm
39    pub fn block_in_place<F, R>(f: F) -> R
40    where
41        F: FnOnce() -> R,
42    {
43        // nosemgrep: ban-raw-block-in-place
44        tokio::task::block_in_place(f)
45    }
46
47    // note: this call does not exist on wasm and you need to handle it
48    // conditionally at the call site of packages that compile on wasm
49    pub fn block_on<F: Future>(future: F) -> F::Output {
50        // nosemgrep: ban-raw-block-on
51        tokio::runtime::Handle::current().block_on(future)
52    }
53
54    pub async fn sleep(duration: Duration) {
55        // nosemgrep: ban-tokio-sleep
56        tokio::time::sleep(duration).await;
57    }
58
59    pub async fn sleep_until(deadline: Instant) {
60        tokio::time::sleep_until(deadline).await;
61    }
62
63    pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed>
64    where
65        T: Future,
66    {
67        tokio::time::timeout(duration, future)
68            .await
69            .map_err(|_| Elapsed)
70    }
71}
72
#[cfg(target_family = "wasm")]
mod r#impl {
    pub use std::convert::Infallible as JoinError;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_util::future::RemoteHandle;
    use futures_util::FutureExt;

    use super::*;

    /// Handle to a task started by [`spawn`], standing in for
    /// `tokio::task::JoinHandle` on wasm.
    ///
    /// Dropping the handle detaches the task (it keeps running); call
    /// [`JoinHandle::abort`] to cancel it.
    #[derive(Debug)]
    pub struct JoinHandle<T> {
        // `None` once `abort` has been called.
        handle: Option<RemoteHandle<T>>,
    }

    impl<T> JoinHandle<T> {
        /// Cancels the task by dropping its `RemoteHandle` (without `forget`).
        ///
        /// NOTE(review): unlike tokio, awaiting this handle after `abort`
        /// never resolves. `JoinError` is `Infallible` here, so cancellation
        /// cannot be reported as an error.
        pub fn abort(&mut self) {
            drop(self.handle.take());
        }
    }

    impl<T> Drop for JoinHandle<T> {
        fn drop(&mut self) {
            // Dropping a `RemoteHandle` would cancel the spawned future;
            // `forget` lets it run to completion, so dropping the handle detaches.
            if let Some(h) = self.handle.take() {
                h.forget();
            }
        }
    }

    impl<T: 'static> Future for JoinHandle<T> {
        type Output = Result<T, JoinError>;

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if let Some(handle) = self.handle.as_mut() {
                Pin::new(handle).poll(cx).map(Ok)
            } else {
                // Aborted: no waker is registered, so this handle never resolves.
                Poll::Pending
            }
        }
    }

    /// Spawns `future` on the browser/JS event loop via
    /// `wasm_bindgen_futures::spawn_local`.
    ///
    /// As on native, the future is instrumented with a root `debug` span
    /// (`parent: None`, target `LOG_RUNTIME`, field `task = name`). `Send` is
    /// not required on wasm.
    pub fn spawn<F, T>(name: &str, future: F) -> JoinHandle<T>
    where
        F: Future<Output = T> + 'static,
    {
        let span = tracing::debug_span!(target: LOG_RUNTIME, parent: None, "spawn", task = name);
        // Instrument before splitting off the remote handle, so the span is
        // entered every time the spawned future is polled.
        let (fut, handle) = future.instrument(span).remote_handle();
        wasm_bindgen_futures::spawn_local(fut);

        JoinHandle {
            handle: Some(handle),
        }
    }

    /// Spawns a `()`-returning, possibly non-`Send` future; on wasm this is
    /// the same as [`spawn`].
    pub(crate) fn spawn_local<F>(name: &str, future: F) -> JoinHandle<()>
    where
        // No Send needed on wasm
        F: Future<Output = ()> + 'static,
    {
        spawn(name, future)
    }

    /// Asynchronously waits for `duration` using a `gloo_timers` timer.
    ///
    /// The duration is clamped to `i32::MAX` milliseconds (about 24.8 days),
    /// presumably because the underlying JS timer takes a signed 32-bit
    /// delay. Longer sleeps therefore return early.
    pub async fn sleep(duration: Duration) {
        gloo_timers::future::sleep(duration.min(Duration::from_millis(i32::MAX as u64))).await
    }

    /// Asynchronously waits until `deadline`. Returns immediately if the
    /// deadline has already passed (the remaining duration saturates to zero).
    ///
    /// NOTE(review): relies on `Instant::now()` being usable on the wasm
    /// target in use; confirm for wasm32-unknown-unknown.
    pub async fn sleep_until(deadline: Instant) {
        // nosemgrep: ban-system-time-now
        // nosemgrep: ban-instant-now
        sleep(deadline.saturating_duration_since(Instant::now())).await
    }

    /// Awaits `future` for at most `duration`, returning `Err(Elapsed)` if the
    /// timer fires first.
    ///
    /// The select is biased toward `future`: if both are ready in the same
    /// poll, the future's output wins.
    pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed>
    where
        T: Future,
    {
        futures::pin_mut!(future);
        futures::select_biased! {
            value = future.fuse() => Ok(value),
            _ = sleep(duration).fuse() => Err(Elapsed),
        }
    }
}