fedimint_core/task/
jit.rs1use std::convert::Infallible;
2use std::sync::Arc;
3use std::{fmt, panic};
4
5use fedimint_core::runtime::JoinHandle;
6use fedimint_logging::LOG_TASK;
7use futures::Future;
8use tokio::sync;
9use tracing::warn;
10
11use super::MaybeSend;
12
13pub type Jit<T> = JitCore<T, Infallible>;
14pub type JitTry<T, E> = JitCore<T, E>;
15pub type JitTryAnyhow<T> = JitCore<T, anyhow::Error>;
16
17#[derive(Debug)]
22pub enum OneTimeError<E> {
23 Original(E),
24 Copy(anyhow::Error),
25}
26
27impl<E> std::error::Error for OneTimeError<E>
28where
29 E: fmt::Debug + fmt::Display,
30{
31 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
32 None
33 }
34
35 fn cause(&self) -> Option<&dyn std::error::Error> {
36 self.source()
37 }
38}
39
40impl<E> fmt::Display for OneTimeError<E>
41where
42 E: fmt::Display,
43{
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 Self::Original(o) => o.fmt(f),
47 Self::Copy(c) => c.fmt(f),
48 }
49 }
50}
51
52#[derive(Debug)]
54pub struct JitCore<T, E> {
55 inner: Arc<JitInner<T, E>>,
56}
57
58#[derive(Debug)]
59struct JitInner<T, E> {
60 handle: sync::Mutex<JoinHandle<Result<T, E>>>,
61 val: sync::OnceCell<Result<T, String>>,
62}
63
64impl<T, E> Clone for JitCore<T, E>
65where
66 T: Clone,
67{
68 fn clone(&self) -> Self {
69 Self {
70 inner: self.inner.clone(),
71 }
72 }
73}
74impl<T, E> Drop for JitInner<T, E> {
75 fn drop(&mut self) {
76 self.handle.get_mut().abort();
77 }
78}
79impl<T, E> JitCore<T, E>
80where
81 T: MaybeSend + 'static,
82 E: MaybeSend + 'static + fmt::Display,
83{
84 pub fn new_try<Fut>(f: impl FnOnce() -> Fut + 'static + MaybeSend) -> Self
89 where
90 Fut: Future<Output = std::result::Result<T, E>> + 'static + MaybeSend,
91 {
92 let handle = crate::runtime::spawn("jit-value", async { f().await });
93
94 Self {
95 inner: JitInner {
96 handle: handle.into(),
97 val: sync::OnceCell::new(),
98 }
99 .into(),
100 }
101 }
102
103 pub async fn get_try(&self) -> Result<&T, OneTimeError<E>> {
106 let mut init_error = None;
107 let value = self
108 .inner
109 .val
110 .get_or_init(|| async {
111 let handle: &mut _ = &mut *self.inner.handle.lock().await;
112 match handle.await {
113 Ok(Ok(o)) => Ok(o),
114 Ok(Err(err)) => {
115 let err_str = err.to_string();
116 init_error = Some(err);
117 Err(err_str)
118 },
119 Err(err) => {
120
121 #[cfg(not(target_family = "wasm"))]
122 if err.is_panic() {
123 warn!(target: LOG_TASK, %err, type_name = %std::any::type_name::<T>(), "Jit value panicked");
124 panic::resume_unwind(err.into_panic());
126 }
127 #[cfg(not(target_family = "wasm"))]
128 if err.is_cancelled() {
129 warn!(target: LOG_TASK, %err, type_name = %std::any::type_name::<T>(), "Jit value task canceled:");
130 }
131 Err(format!("Jit value {} failed unexpectedly with: {}", std::any::type_name::<T>(), err))
132 },
133 }
134 })
135 .await;
136 if let Some(err) = init_error {
137 return Err(OneTimeError::Original(err));
138 }
139 value
140 .as_ref()
141 .map_err(|err_str| OneTimeError::Copy(anyhow::Error::msg(err_str.to_owned())))
142 }
143}
144impl<T> JitCore<T, Infallible>
145where
146 T: MaybeSend + 'static,
147{
148 pub fn new<Fut>(f: impl FnOnce() -> Fut + 'static + MaybeSend) -> Self
149 where
150 Fut: Future<Output = T> + 'static + MaybeSend,
151 T: 'static,
152 {
153 Self::new_try(|| async { Ok(f().await) })
154 }
155
156 pub async fn get(&self) -> &T {
157 self.get_try().await.expect("can't fail")
158 }
159}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use anyhow::bail;

    use super::{Jit, JitTry, JitTryAnyhow, OneTimeError};

    #[test_log::test(tokio::test)]
    async fn sanity_jit() {
        let v = Jit::new(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            3
        });

        assert_eq!(*v.get().await, 3);
        assert_eq!(*v.get().await, 3);
        assert_eq!(*v.clone().get().await, 3);
    }

    #[test_log::test(tokio::test)]
    async fn sanity_jit_try_ok() {
        let v = JitTryAnyhow::new_try(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            Ok(3)
        });

        assert_eq!(*v.get_try().await.expect("ok"), 3);
        assert_eq!(*v.get_try().await.expect("ok"), 3);
        assert_eq!(*v.clone().get_try().await.expect("ok"), 3);
    }

    #[test_log::test(tokio::test)]
    async fn sanity_jit_try_err() {
        let v = JitTry::new_try(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            bail!("BOOM");
            #[allow(unreachable_code)]
            Ok(3)
        });

        // The call that performs initialization receives the original error value.
        match v.get_try().await {
            Err(OneTimeError::Original(err)) => assert_eq!(err.to_string(), "BOOM"),
            other => panic!("expected original error, got {other:?}"),
        }
        // Every later call, including through a clone, gets a message-preserving copy.
        match v.get_try().await {
            Err(OneTimeError::Copy(err)) => assert_eq!(err.to_string(), "BOOM"),
            other => panic!("expected copied error, got {other:?}"),
        }
        match v.clone().get_try().await {
            Err(OneTimeError::Copy(err)) => assert_eq!(err.to_string(), "BOOM"),
            other => panic!("expected copied error, got {other:?}"),
        }
    }
}