fedimint_core/
task.rs

1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5/// Just-in-time initialization
6pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{pin, Pin};
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use thiserror::Error;
20use tokio::sync::{oneshot, watch};
21use tracing::{debug, error, info, trace};
22
23use crate::runtime;
24// TODO: stop using `task::*`, and use `runtime::*` in the code
25// lots of churn though
26pub use crate::runtime::*;
27/// A group of task working together
28///
29/// Using this struct it is possible to spawn one or more
30/// main thread collaborating, which can cooperatively gracefully
31/// shut down, either due to external request, or failure of
32/// one of them.
33///
34/// Each thread should periodically check [`TaskHandle`] or rely
35/// on condition like channel disconnection to detect when it is time
36/// to finish.
37#[derive(Clone, Default, Debug)]
38pub struct TaskGroup {
39    inner: Arc<TaskGroupInner>,
40}
41
42impl TaskGroup {
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    pub fn make_handle(&self) -> TaskHandle {
48        TaskHandle {
49            inner: self.inner.clone(),
50        }
51    }
52
53    /// Create a sub-group
54    ///
55    /// Task subgroup works like an independent [`TaskGroup`], but the parent
56    /// `TaskGroup` will propagate the shut down signal to a sub-group.
57    ///
58    /// In contrast to using the parent group directly, a subgroup allows
59    /// calling [`Self::join_all`] and detecting any panics on just a
60    /// subset of tasks.
61    ///
62    /// The code create a subgroup is responsible for calling
63    /// [`Self::join_all`]. If it won't, the parent subgroup **will not**
64    /// detect any panics in the tasks spawned by the subgroup.
65    pub fn make_subgroup(&self) -> Self {
66        let new_tg = Self::new();
67        self.inner.add_subgroup(new_tg.clone());
68        new_tg
69    }
70
71    /// Tell all tasks in the group to shut down. This only initiates the
72    /// shutdown process, it does not wait for the tasks to shut down.
73    pub fn shutdown(&self) {
74        self.inner.shutdown();
75    }
76
77    /// Tell all tasks in the group to shut down and wait for them to finish.
78    pub async fn shutdown_join_all(
79        self,
80        join_timeout: impl Into<Option<Duration>>,
81    ) -> Result<(), anyhow::Error> {
82        self.shutdown();
83        self.join_all(join_timeout.into()).await
84    }
85
86    /// Add a task to the group that waits for CTRL+C or SIGTERM, then
87    /// tells the rest of the task group to shut down.
88    #[cfg(not(target_family = "wasm"))]
89    pub fn install_kill_handler(&self) {
90        /// Wait for CTRL+C or SIGTERM.
91        async fn wait_for_shutdown_signal() {
92            use tokio::signal;
93
94            let ctrl_c = async {
95                signal::ctrl_c()
96                    .await
97                    .expect("failed to install Ctrl+C handler");
98            };
99
100            #[cfg(unix)]
101            let terminate = async {
102                signal::unix::signal(signal::unix::SignalKind::terminate())
103                    .expect("failed to install signal handler")
104                    .recv()
105                    .await;
106            };
107
108            #[cfg(not(unix))]
109            let terminate = std::future::pending::<()>();
110
111            tokio::select! {
112                () = ctrl_c => {},
113                () = terminate => {},
114            }
115        }
116
117        runtime::spawn("kill handlers", {
118            let task_group = self.clone();
119            async move {
120                wait_for_shutdown_signal().await;
121                info!(
122                    target: LOG_TASK,
123                    "signal received, starting graceful shutdown"
124                );
125                task_group.shutdown();
126            }
127        });
128    }
129
130    pub fn spawn<Fut, R>(
131        &self,
132        name: impl Into<String>,
133        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
134    ) -> oneshot::Receiver<R>
135    where
136        Fut: Future<Output = R> + MaybeSend + 'static,
137        R: MaybeSend + 'static,
138    {
139        self.spawn_inner(name, f, false)
140    }
141
142    /// This is a version of [`Self::spawn`] that uses less noisy logging level
143    ///
144    /// Meant for tasks that are spawned often enough to not be as interesting.
145    pub fn spawn_silent<Fut, R>(
146        &self,
147        name: impl Into<String>,
148        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
149    ) -> oneshot::Receiver<R>
150    where
151        Fut: Future<Output = R> + MaybeSend + 'static,
152        R: MaybeSend + 'static,
153    {
154        self.spawn_inner(name, f, true)
155    }
156
157    fn spawn_inner<Fut, R>(
158        &self,
159        name: impl Into<String>,
160        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
161        quiet: bool,
162    ) -> oneshot::Receiver<R>
163    where
164        Fut: Future<Output = R> + MaybeSend + 'static,
165        R: MaybeSend + 'static,
166    {
167        let name = name.into();
168        let mut guard = TaskPanicGuard {
169            name: name.clone(),
170            inner: self.inner.clone(),
171            completed: false,
172        };
173        let handle = self.make_handle();
174
175        let (tx, rx) = oneshot::channel();
176        let handle = crate::runtime::spawn(&name, {
177            let name = name.clone();
178            async move {
179                // Unfortunately log levels need to be static
180                if quiet {
181                    trace!(target: LOG_TASK, "Starting task {name}");
182                } else {
183                    debug!(target: LOG_TASK, "Starting task {name}");
184                }
185                // if receiver is not interested, just drop the message
186                let r = f(handle).await;
187                if quiet {
188                    trace!(target: LOG_TASK, "Finished task {name}");
189                } else {
190                    debug!(target: LOG_TASK, "Finished task {name}");
191                }
192                let _ = tx.send(r);
193            }
194        });
195        self.inner.add_join_handle(name, handle);
196        guard.completed = true;
197
198        rx
199    }
200
201    /// Spawn a task that will get cancelled automatically on `TaskGroup`
202    /// shutdown.
203    pub fn spawn_cancellable<R>(
204        &self,
205        name: impl Into<String>,
206        future: impl Future<Output = R> + MaybeSend + 'static,
207    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
208    where
209        R: MaybeSend + 'static,
210    {
211        self.spawn(name, |handle| async move {
212            let value = handle.cancel_on_shutdown(future).await;
213            if value.is_err() {
214                // name will part of span
215                debug!(target: LOG_TASK, "task cancelled on shutdown");
216            }
217            value
218        })
219    }
220
221    pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
222        let deadline = timeout.map(|timeout| now() + timeout);
223        let mut errors = vec![];
224
225        self.join_all_inner(deadline, &mut errors).await;
226
227        if errors.is_empty() {
228            Ok(())
229        } else {
230            let num_errors = errors.len();
231            bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
232        }
233    }
234
235    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
236    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
237    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
238        self.inner.join_all(deadline, errors).await;
239    }
240}
241
242struct TaskPanicGuard {
243    name: String,
244    inner: Arc<TaskGroupInner>,
245    /// Did the future completed successfully (no panic)
246    completed: bool,
247}
248
249impl Drop for TaskPanicGuard {
250    fn drop(&mut self) {
251        if !self.completed {
252            info!(
253                target: LOG_TASK,
254                "Task {} shut down uncleanly. Shutting down task group.", self.name
255            );
256            self.inner.shutdown();
257        }
258    }
259}
260
261#[derive(Clone, Debug)]
262pub struct TaskHandle {
263    inner: Arc<TaskGroupInner>,
264}
265
266#[derive(thiserror::Error, Debug, Clone)]
267#[error("Task group is shutting down")]
268#[non_exhaustive]
269pub struct ShuttingDownError {}
270
271impl TaskHandle {
272    /// Is task group shutting down?
273    ///
274    /// Every task in a task group should detect and stop if `true`.
275    pub fn is_shutting_down(&self) -> bool {
276        self.inner.is_shutting_down()
277    }
278
279    /// Make a [`oneshot::Receiver`] that will fire on shutdown
280    ///
281    /// Tasks can use `select` on the return value to handle shutdown
282    /// signal during otherwise blocking operation.
283    pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
284        self.inner.make_shutdown_rx()
285    }
286
287    /// Run the future or cancel it if the [`TaskGroup`] shuts down.
288    pub async fn cancel_on_shutdown<F: Future>(
289        &self,
290        fut: F,
291    ) -> Result<F::Output, ShuttingDownError> {
292        let rx = self.make_shutdown_rx();
293        match future::select(pin!(rx), pin!(fut)).await {
294            Either::Left(((), _)) => Err(ShuttingDownError {}),
295            Either::Right((value, _)) => Ok(value),
296        }
297    }
298}
299
300pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
301
302impl TaskShutdownToken {
303    fn new(mut rx: watch::Receiver<bool>) -> Self {
304        Self(Box::pin(async move {
305            let _ = rx.wait_for(|v| *v).await;
306        }))
307    }
308}
309
310impl Future for TaskShutdownToken {
311    type Output = ();
312
313    fn poll(
314        mut self: Pin<&mut Self>,
315        cx: &mut std::task::Context<'_>,
316    ) -> std::task::Poll<Self::Output> {
317        self.0.as_mut().poll(cx)
318    }
319}
320
/// `async_trait` that requires `Send` futures only on non-wasm targets
///
/// Expands to `#[async_trait]` on non-wasm targets and to
/// `#[async_trait(?Send)]` on wasm, applied to the wrapped item.
///
/// # Example
///
/// ```rust
/// use fedimint_core::{apply, async_trait_maybe_send};
/// #[apply(async_trait_maybe_send!)]
/// trait Foo {
///     // methods
/// }
///
/// #[apply(async_trait_maybe_send!)]
/// impl Foo for () {
///     // methods
/// }
/// ```
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        $($item)*
    };
}
345
/// Append `+ Send` to a type on non-wasm targets
///
/// `MaybeSend` can not be used in `dyn $Trait + MaybeSend`, so this macro
/// adds the bound textually instead.
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::{apply, maybe_add_send};
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($ty:tt)*) => {
        $($ty)* + Send
    };
}

/// Leave a type unchanged on wasm targets
///
/// `MaybeSend` can not be used in `dyn $Trait + MaybeSend`; on wasm no
/// `Send` bound is added at all.
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::{apply, maybe_add_send};
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($ty:tt)*) => {
        $($ty)*
    };
}
378
/// Append `+ Send + Sync` to a type on non-wasm targets
///
/// See `maybe_add_send`
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty:tt)*) => {
        $($ty)* + Send + Sync
    };
}

/// Leave a type unchanged on wasm targets
///
/// See `maybe_add_send`
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty:tt)*) => {
        $($ty)*
    };
}
396
/// `MaybeSend` is a no-op on wasm and `Send` on non-wasm.
///
/// On wasm, most types don't implement `Send` because JS types cannot be
/// sent between workers directly.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

/// On wasm every type is `MaybeSend`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}

/// `MaybeSend` is a no-op on wasm and `Send` on non-wasm.
///
/// On wasm, most types don't implement `Send` because JS types cannot be
/// sent between workers directly.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

/// Outside of wasm exactly the `Send` types are `MaybeSend`.
#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSend for T where T: Send {}
416
/// `MaybeSync` is a no-op on wasm and `Sync` on non-wasm.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

/// On wasm every type is `MaybeSync`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}

/// `MaybeSync` is a no-op on wasm and `Sync` on non-wasm.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

/// Outside of wasm exactly the `Sync` types are `MaybeSync`.
#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSync for T where T: Sync {}
430
431// Used in tests when sleep functionality is desired so it can be logged.
432// Must include comment describing the reason for sleeping.
433pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
434    info!(
435        target: LOG_TEST,
436        "Sleeping for {}.{:03} seconds because: {}",
437        duration.as_secs(),
438        duration.subsec_millis(),
439        comment.as_ref()
440    );
441    sleep(duration).await;
442}
443
444/// An error used as a "cancelled" marker in [`Cancellable`].
445#[derive(Error, Debug)]
446#[error("Operation cancelled")]
447pub struct Cancelled;
448
449/// Operation that can potentially get cancelled returning no result (e.g.
450/// program shutdown).
451pub type Cancellable<T> = std::result::Result<T, Cancelled>;
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Delay used to order task start-up and shutdown relative to each other.
    const SHORT_DELAY: Duration = Duration::from_millis(10);

    /// Task body that finishes as soon as the group signals shutdown.
    async fn wait_for_shutdown(handle: TaskHandle) {
        handle.make_shutdown_rx().await;
    }

    /// Task body that only starts listening for shutdown after a delay, so
    /// the shutdown signal is already set when it subscribes.
    async fn wait_for_shutdown_late(handle: TaskHandle) {
        sleep(SHORT_DELAY).await;
        handle.make_shutdown_rx().await;
    }

    /// Shutdown is requested after the task started waiting for it.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        tg.spawn("shutdown waiter", wait_for_shutdown);
        sleep(SHORT_DELAY).await;
        tg.shutdown_join_all(None).await
    }

    /// Shutdown is requested before the task starts waiting for it.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        tg.spawn("shutdown waiter", wait_for_shutdown_late);
        tg.shutdown_join_all(None).await
    }

    /// Like `shutdown_task_group_after`, but the task lives in a subgroup.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        let subgroup = tg.make_subgroup();
        subgroup.spawn("shutdown waiter", wait_for_shutdown);
        drop(subgroup);
        sleep(SHORT_DELAY).await;
        tg.shutdown_join_all(None).await
    }

    /// Like `shutdown_task_group_before`, but the task lives in a subgroup.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        let subgroup = tg.make_subgroup();
        subgroup.spawn("shutdown waiter", wait_for_shutdown_late);
        drop(subgroup);
        tg.shutdown_join_all(None).await
    }
}