// fedimint_core/task/jit.rs

use std::convert::Infallible;
use std::sync::Arc;
use std::{fmt, panic};

use fedimint_core::runtime::JoinHandle;
use fedimint_logging::LOG_TASK;
use futures::Future;
use tokio::sync;
use tracing::warn;

use super::MaybeSend;

13pub type Jit<T> = JitCore<T, Infallible>;
14pub type JitTry<T, E> = JitCore<T, E>;
15pub type JitTryAnyhow<T> = JitCore<T, anyhow::Error>;
16
17/// Error that could have been returned before
18///
19/// Newtype over `Option<E>` that allows better user (error conversion mostly)
20/// experience
21#[derive(Debug)]
22pub enum OneTimeError<E> {
23    Original(E),
24    Copy(anyhow::Error),
25}
26
27impl<E> std::error::Error for OneTimeError<E>
28where
29    E: fmt::Debug + fmt::Display,
30{
31    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32        None
33    }
34
35    fn cause(&self) -> Option<&dyn std::error::Error> {
36        self.source()
37    }
38}
39
40impl<E> fmt::Display for OneTimeError<E>
41where
42    E: fmt::Display,
43{
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            Self::Original(o) => o.fmt(f),
47            Self::Copy(c) => c.fmt(f),
48        }
49    }
50}
51
52/// A value that initializes eagerly in parallel in a falliable way
53#[derive(Debug)]
54pub struct JitCore<T, E> {
55    inner: Arc<JitInner<T, E>>,
56}
57
58#[derive(Debug)]
59struct JitInner<T, E> {
60    handle: sync::Mutex<JoinHandle<Result<T, E>>>,
61    val: sync::OnceCell<Result<T, String>>,
62}
63
64impl<T, E> Clone for JitCore<T, E>
65where
66    T: Clone,
67{
68    fn clone(&self) -> Self {
69        Self {
70            inner: self.inner.clone(),
71        }
72    }
73}
74impl<T, E> Drop for JitInner<T, E> {
75    fn drop(&mut self) {
76        self.handle.get_mut().abort();
77    }
78}
79impl<T, E> JitCore<T, E>
80where
81    T: MaybeSend + 'static,
82    E: MaybeSend + 'static + fmt::Display,
83{
84    /// Create `JitTry` value, and spawn a future `f` that computes its value
85    ///
86    /// Unlike normal Rust futures, the `f` executes eagerly (is spawned as a
87    /// tokio task).
88    pub fn new_try<Fut>(f: impl FnOnce() -> Fut + 'static + MaybeSend) -> Self
89    where
90        Fut: Future<Output = std::result::Result<T, E>> + 'static + MaybeSend,
91    {
92        let handle = crate::runtime::spawn("jit-value", async { f().await });
93
94        Self {
95            inner: JitInner {
96                handle: handle.into(),
97                val: sync::OnceCell::new(),
98            }
99            .into(),
100        }
101    }
102
103    /// Get the reference to the value, potentially blocking for the
104    /// initialization future to complete
105    pub async fn get_try(&self) -> Result<&T, OneTimeError<E>> {
106        let mut init_error = None;
107        let value = self
108            .inner
109            .val
110            .get_or_init(|| async {
111                let handle: &mut _ = &mut *self.inner.handle.lock().await;
112                match handle.await {
113                        Ok(Ok(o)) => Ok(o),
114                        Ok(Err(err)) => {
115                            let err_str = err.to_string();
116                            init_error = Some(err);
117                            Err(err_str)
118                        },
119                        Err(err) => {
120
121                            #[cfg(not(target_family = "wasm"))]
122                            if err.is_panic() {
123                                warn!(target: LOG_TASK, %err, type_name = %std::any::type_name::<T>(), "Jit value panicked");
124                                // Resume the panic on the main task
125                                panic::resume_unwind(err.into_panic());
126                            }
127                            #[cfg(not(target_family = "wasm"))]
128                            if err.is_cancelled() {
129                                warn!(target: LOG_TASK, %err, type_name = %std::any::type_name::<T>(), "Jit value task canceled:");
130                            }
131                            Err(format!("Jit value {} failed unexpectedly with: {}", std::any::type_name::<T>(), err))
132                        },
133                    }
134            })
135            .await;
136        if let Some(err) = init_error {
137            return Err(OneTimeError::Original(err));
138        }
139        value
140            .as_ref()
141            .map_err(|err_str| OneTimeError::Copy(anyhow::Error::msg(err_str.to_owned())))
142    }
143}
144impl<T> JitCore<T, Infallible>
145where
146    T: MaybeSend + 'static,
147{
148    pub fn new<Fut>(f: impl FnOnce() -> Fut + 'static + MaybeSend) -> Self
149    where
150        Fut: Future<Output = T> + 'static + MaybeSend,
151        T: 'static,
152    {
153        Self::new_try(|| async { Ok(f().await) })
154    }
155
156    pub async fn get(&self) -> &T {
157        self.get_try().await.expect("can't fail")
158    }
159}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use anyhow::bail;

    use super::{Jit, JitTry, JitTryAnyhow, OneTimeError};

    #[test_log::test(tokio::test)]
    async fn sanity_jit() {
        let v = Jit::new(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            3
        });

        assert_eq!(*v.get().await, 3);
        assert_eq!(*v.get().await, 3);
        assert_eq!(*v.clone().get().await, 3);
    }

    #[test_log::test(tokio::test)]
    async fn sanity_jit_try_ok() {
        let v = JitTryAnyhow::new_try(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            Ok(3)
        });

        assert_eq!(*v.get_try().await.expect("ok"), 3);
        assert_eq!(*v.get_try().await.expect("ok"), 3);
        assert_eq!(*v.clone().get_try().await.expect("ok"), 3);
    }

    #[test_log::test(tokio::test)]
    async fn sanity_jit_try_err() {
        let v = JitTry::new_try(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            bail!("BOOM");
            // Unreachable, but pins the `Ok` type of the future's output.
            #[allow(unreachable_code)]
            Ok(3)
        });

        // The call that runs the initialization receives the original error...
        assert!(matches!(v.get_try().await, Err(OneTimeError::Original(_))));
        // ...while every later call, including through clones, gets a copy.
        assert!(matches!(v.get_try().await, Err(OneTimeError::Copy(_))));
        assert!(matches!(
            v.clone().get_try().await,
            Err(OneTimeError::Copy(_))
        ));
        // The copy preserves the original error's message.
        assert_eq!(v.get_try().await.unwrap_err().to_string(), "BOOM");
    }
}
205}