fedimint_core/
task.rs

1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5/// Just-in-time initialization
6pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{pin, Pin};
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use thiserror::Error;
20use tokio::sync::{oneshot, watch};
21use tracing::{debug, error, info};
22
23use crate::runtime;
24// TODO: stop using `task::*`, and use `runtime::*` in the code
25// lots of churn though
26pub use crate::runtime::*;
27/// A group of task working together
28///
29/// Using this struct it is possible to spawn one or more
30/// main thread collaborating, which can cooperatively gracefully
31/// shut down, either due to external request, or failure of
32/// one of them.
33///
34/// Each thread should periodically check [`TaskHandle`] or rely
35/// on condition like channel disconnection to detect when it is time
36/// to finish.
37#[derive(Clone, Default, Debug)]
38pub struct TaskGroup {
39    inner: Arc<TaskGroupInner>,
40}
41
42impl TaskGroup {
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    pub fn make_handle(&self) -> TaskHandle {
48        TaskHandle {
49            inner: self.inner.clone(),
50        }
51    }
52
53    /// Create a sub-group
54    ///
55    /// Task subgroup works like an independent [`TaskGroup`], but the parent
56    /// `TaskGroup` will propagate the shut down signal to a sub-group.
57    ///
58    /// In contrast to using the parent group directly, a subgroup allows
59    /// calling [`Self::join_all`] and detecting any panics on just a
60    /// subset of tasks.
61    ///
62    /// The code create a subgroup is responsible for calling
63    /// [`Self::join_all`]. If it won't, the parent subgroup **will not**
64    /// detect any panics in the tasks spawned by the subgroup.
65    pub fn make_subgroup(&self) -> Self {
66        let new_tg = Self::new();
67        self.inner.add_subgroup(new_tg.clone());
68        new_tg
69    }
70
71    /// Tell all tasks in the group to shut down. This only initiates the
72    /// shutdown process, it does not wait for the tasks to shut down.
73    pub fn shutdown(&self) {
74        self.inner.shutdown();
75    }
76
77    /// Tell all tasks in the group to shut down and wait for them to finish.
78    pub async fn shutdown_join_all(
79        self,
80        join_timeout: impl Into<Option<Duration>>,
81    ) -> Result<(), anyhow::Error> {
82        self.shutdown();
83        self.join_all(join_timeout.into()).await
84    }
85
86    /// Add a task to the group that waits for CTRL+C or SIGTERM, then
87    /// tells the rest of the task group to shut down.
88    #[cfg(not(target_family = "wasm"))]
89    pub fn install_kill_handler(&self) {
90        /// Wait for CTRL+C or SIGTERM.
91        async fn wait_for_shutdown_signal() {
92            use tokio::signal;
93
94            let ctrl_c = async {
95                signal::ctrl_c()
96                    .await
97                    .expect("failed to install Ctrl+C handler");
98            };
99
100            #[cfg(unix)]
101            let terminate = async {
102                signal::unix::signal(signal::unix::SignalKind::terminate())
103                    .expect("failed to install signal handler")
104                    .recv()
105                    .await;
106            };
107
108            #[cfg(not(unix))]
109            let terminate = std::future::pending::<()>();
110
111            tokio::select! {
112                () = ctrl_c => {},
113                () = terminate => {},
114            }
115        }
116
117        runtime::spawn("kill handlers", {
118            let task_group = self.clone();
119            async move {
120                wait_for_shutdown_signal().await;
121                info!(
122                    target: LOG_TASK,
123                    "signal received, starting graceful shutdown"
124                );
125                task_group.shutdown();
126            }
127        });
128    }
129
130    pub fn spawn<Fut, R>(
131        &self,
132        name: impl Into<String>,
133        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
134    ) -> oneshot::Receiver<R>
135    where
136        Fut: Future<Output = R> + MaybeSend + 'static,
137        R: MaybeSend + 'static,
138    {
139        let name = name.into();
140        let mut guard = TaskPanicGuard {
141            name: name.clone(),
142            inner: self.inner.clone(),
143            completed: false,
144        };
145        let handle = self.make_handle();
146
147        let (tx, rx) = oneshot::channel();
148        let handle = crate::runtime::spawn(&name, {
149            let name = name.clone();
150            async move {
151                // if receiver is not interested, just drop the message
152                debug!("Starting task {name}");
153                let r = f(handle).await;
154                debug!("Finished task {name}");
155                let _ = tx.send(r);
156            }
157        });
158        self.inner.add_join_handle(name, handle);
159        guard.completed = true;
160
161        rx
162    }
163
164    /// Spawn a task that will get cancelled automatically on `TaskGroup`
165    /// shutdown.
166    pub fn spawn_cancellable<R>(
167        &self,
168        name: impl Into<String>,
169        future: impl Future<Output = R> + MaybeSend + 'static,
170    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
171    where
172        R: MaybeSend + 'static,
173    {
174        self.spawn(name, |handle| async move {
175            let value = handle.cancel_on_shutdown(future).await;
176            if value.is_err() {
177                // name will part of span
178                debug!("task cancelled on shutdown");
179            }
180            value
181        })
182    }
183
184    pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
185        let deadline = timeout.map(|timeout| now() + timeout);
186        let mut errors = vec![];
187
188        self.join_all_inner(deadline, &mut errors).await;
189
190        if errors.is_empty() {
191            Ok(())
192        } else {
193            let num_errors = errors.len();
194            bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
195        }
196    }
197
198    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
199    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
200    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
201        self.inner.join_all(deadline, errors).await;
202    }
203}
204
205struct TaskPanicGuard {
206    name: String,
207    inner: Arc<TaskGroupInner>,
208    /// Did the future completed successfully (no panic)
209    completed: bool,
210}
211
212impl Drop for TaskPanicGuard {
213    fn drop(&mut self) {
214        if !self.completed {
215            info!(
216                target: LOG_TASK,
217                "Task {} shut down uncleanly. Shutting down task group.", self.name
218            );
219            self.inner.shutdown();
220        }
221    }
222}
223
224#[derive(Clone, Debug)]
225pub struct TaskHandle {
226    inner: Arc<TaskGroupInner>,
227}
228
229#[derive(thiserror::Error, Debug, Clone)]
230#[error("Task group is shutting down")]
231#[non_exhaustive]
232pub struct ShuttingDownError {}
233
234impl TaskHandle {
235    /// Is task group shutting down?
236    ///
237    /// Every task in a task group should detect and stop if `true`.
238    pub fn is_shutting_down(&self) -> bool {
239        self.inner.is_shutting_down()
240    }
241
242    /// Make a [`oneshot::Receiver`] that will fire on shutdown
243    ///
244    /// Tasks can use `select` on the return value to handle shutdown
245    /// signal during otherwise blocking operation.
246    pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
247        self.inner.make_shutdown_rx()
248    }
249
250    /// Run the future or cancel it if the [`TaskGroup`] shuts down.
251    pub async fn cancel_on_shutdown<F: Future>(
252        &self,
253        fut: F,
254    ) -> Result<F::Output, ShuttingDownError> {
255        let rx = self.make_shutdown_rx();
256        match future::select(pin!(rx), pin!(fut)).await {
257            Either::Left(((), _)) => Err(ShuttingDownError {}),
258            Either::Right((value, _)) => Ok(value),
259        }
260    }
261}
262
/// Future that resolves when the owning task group signals shutdown.
///
/// Obtained from `TaskHandle::make_shutdown_rx`; it wraps a boxed, `Send`
/// future so it can be moved across tasks.
pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
264
265impl TaskShutdownToken {
266    fn new(mut rx: watch::Receiver<bool>) -> Self {
267        Self(Box::pin(async move {
268            let _ = rx.wait_for(|v| *v).await;
269        }))
270    }
271}
272
273impl Future for TaskShutdownToken {
274    type Output = ();
275
276    fn poll(
277        mut self: Pin<&mut Self>,
278        cx: &mut std::task::Context<'_>,
279    ) -> std::task::Poll<Self::Output> {
280        self.0.as_mut().poll(cx)
281    }
282}
283
/// Apply `#[async_trait]` to a trait or impl, requiring `Send` futures on
/// native targets and opting out of `Send` (`?Send`) on wasm.
///
/// This is the `async_trait` counterpart of `MaybeSend`: on wasm most types
/// are not `Send`, so the generated futures cannot be either.
///
/// # Example
///
/// ```rust
/// use fedimint_core::{apply, async_trait_maybe_send};
/// #[apply(async_trait_maybe_send!)]
/// trait Foo {
///     // methods
/// }
///
/// #[apply(async_trait_maybe_send!)]
/// impl Foo for () {
///     // methods
/// }
/// ```
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($tt:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($tt)*
    };
}
308
/// Add a `Send` bound to a trait object type, but only on non-wasm targets.
///
/// `MaybeSend` cannot be used in `dyn $Trait + MaybeSend`, because only auto
/// traits (like `Send`) may be added to a trait object and `MaybeSend` is a
/// regular trait. This macro expands to `$Trait + Send` on native targets and
/// to just `$Trait` on wasm.
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::maybe_add_send;
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($tt:tt)*) => {
        $($tt)* + Send
    };
}

/// Wasm variant of `maybe_add_send!`: expands to the given type unchanged,
/// since most types on wasm are not `Send`.
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::maybe_add_send;
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($tt:tt)*) => {
        $($tt)*
    };
}
341
/// Add `Send + Sync` bounds to a trait object type, but only on non-wasm
/// targets.
///
/// A plain `MaybeSend`/`MaybeSync` bound cannot be added to a trait object,
/// because only auto traits may be; this macro adds the real auto traits on
/// native targets instead.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty_tokens:tt)*) => {
        $($ty_tokens)* + Send + Sync
    };
}

/// Wasm variant of `maybe_add_send_sync!`: expands to the given type
/// unchanged, matching the no-op `MaybeSend`/`MaybeSync` on wasm.
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty_tokens:tt)*) => {
        $($ty_tokens)*
    };
}
359
/// `MaybeSend` is a no-op marker on wasm and equivalent to `Send` elsewhere.
///
/// On wasm most types don't implement `Send`, because JS values cannot be
/// sent between workers directly.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

/// On wasm, every type is `MaybeSend`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}

/// `MaybeSend` is a no-op marker on wasm and equivalent to `Send` elsewhere.
///
/// On wasm most types don't implement `Send`, because JS values cannot be
/// sent between workers directly.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

/// On non-wasm targets, every `Send` type is `MaybeSend`.
#[cfg(not(target_family = "wasm"))]
impl<T: Send> MaybeSend for T {}
379
/// `MaybeSync` is a no-op marker on wasm and equivalent to `Sync` elsewhere.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

/// On wasm, every type is `MaybeSync`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}

/// `MaybeSync` is a no-op marker on wasm and equivalent to `Sync` elsewhere.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

/// On non-wasm targets, every `Sync` type is `MaybeSync`.
#[cfg(not(target_family = "wasm"))]
impl<T: Sync> MaybeSync for T {}
393
394// Used in tests when sleep functionality is desired so it can be logged.
395// Must include comment describing the reason for sleeping.
396pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
397    info!(
398        target: LOG_TEST,
399        "Sleeping for {}.{:03} seconds because: {}",
400        duration.as_secs(),
401        duration.subsec_millis(),
402        comment.as_ref()
403    );
404    sleep(duration).await;
405}
406
407/// An error used as a "cancelled" marker in [`Cancellable`].
408#[derive(Error, Debug)]
409#[error("Operation cancelled")]
410pub struct Cancelled;
411
412/// Operation that can potentially get cancelled returning no result (e.g.
413/// program shutdown).
414pub type Cancellable<T> = std::result::Result<T, Cancelled>;
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Shutdown is requested while the task is already waiting for it.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        let finished = tg.spawn("shutdown waiter", |handle| async move {
            handle.make_shutdown_rx().await;
        });
        sleep(Duration::from_millis(10)).await;
        tg.shutdown_join_all(None).await?;
        // The task must have actually run to completion, not just been dropped.
        finished.await?;
        Ok(())
    }

    /// Shutdown is requested before the task starts waiting for it.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        let finished = tg.spawn("shutdown waiter", |handle| async move {
            sleep(Duration::from_millis(10)).await;
            handle.make_shutdown_rx().await;
        });
        tg.shutdown_join_all(None).await?;
        finished.await?;
        Ok(())
    }

    /// Parent shutdown reaches a sub-group task that is already waiting.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        let finished = tg
            .make_subgroup()
            .spawn("shutdown waiter", |handle| async move {
                handle.make_shutdown_rx().await;
            });
        sleep(Duration::from_millis(10)).await;
        tg.shutdown_join_all(None).await?;
        finished.await?;
        Ok(())
    }

    /// Parent shutdown reaches a sub-group task before it starts waiting.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        let finished = tg
            .make_subgroup()
            .spawn("shutdown waiter", |handle| async move {
                sleep(Duration::from_millis(10)).await;
                handle.make_shutdown_rx().await;
            });
        tg.shutdown_join_all(None).await?;
        finished.await?;
        Ok(())
    }
}