1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{pin, Pin};
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use thiserror::Error;
20use tokio::sync::{oneshot, watch};
21use tracing::{debug, error, info, trace};
22
23use crate::runtime;
24pub use crate::runtime::*;
27#[derive(Clone, Default, Debug)]
38pub struct TaskGroup {
39 inner: Arc<TaskGroupInner>,
40}
41
42impl TaskGroup {
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn make_handle(&self) -> TaskHandle {
48 TaskHandle {
49 inner: self.inner.clone(),
50 }
51 }
52
53 pub fn make_subgroup(&self) -> Self {
66 let new_tg = Self::new();
67 self.inner.add_subgroup(new_tg.clone());
68 new_tg
69 }
70
71 pub fn shutdown(&self) {
74 self.inner.shutdown();
75 }
76
77 pub async fn shutdown_join_all(
79 self,
80 join_timeout: impl Into<Option<Duration>>,
81 ) -> Result<(), anyhow::Error> {
82 self.shutdown();
83 self.join_all(join_timeout.into()).await
84 }
85
86 #[cfg(not(target_family = "wasm"))]
89 pub fn install_kill_handler(&self) {
90 async fn wait_for_shutdown_signal() {
92 use tokio::signal;
93
94 let ctrl_c = async {
95 signal::ctrl_c()
96 .await
97 .expect("failed to install Ctrl+C handler");
98 };
99
100 #[cfg(unix)]
101 let terminate = async {
102 signal::unix::signal(signal::unix::SignalKind::terminate())
103 .expect("failed to install signal handler")
104 .recv()
105 .await;
106 };
107
108 #[cfg(not(unix))]
109 let terminate = std::future::pending::<()>();
110
111 tokio::select! {
112 () = ctrl_c => {},
113 () = terminate => {},
114 }
115 }
116
117 runtime::spawn("kill handlers", {
118 let task_group = self.clone();
119 async move {
120 wait_for_shutdown_signal().await;
121 info!(
122 target: LOG_TASK,
123 "signal received, starting graceful shutdown"
124 );
125 task_group.shutdown();
126 }
127 });
128 }
129
130 pub fn spawn<Fut, R>(
131 &self,
132 name: impl Into<String>,
133 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
134 ) -> oneshot::Receiver<R>
135 where
136 Fut: Future<Output = R> + MaybeSend + 'static,
137 R: MaybeSend + 'static,
138 {
139 self.spawn_inner(name, f, false)
140 }
141
142 pub fn spawn_silent<Fut, R>(
146 &self,
147 name: impl Into<String>,
148 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
149 ) -> oneshot::Receiver<R>
150 where
151 Fut: Future<Output = R> + MaybeSend + 'static,
152 R: MaybeSend + 'static,
153 {
154 self.spawn_inner(name, f, true)
155 }
156
157 fn spawn_inner<Fut, R>(
158 &self,
159 name: impl Into<String>,
160 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
161 quiet: bool,
162 ) -> oneshot::Receiver<R>
163 where
164 Fut: Future<Output = R> + MaybeSend + 'static,
165 R: MaybeSend + 'static,
166 {
167 let name = name.into();
168 let mut guard = TaskPanicGuard {
169 name: name.clone(),
170 inner: self.inner.clone(),
171 completed: false,
172 };
173 let handle = self.make_handle();
174
175 let (tx, rx) = oneshot::channel();
176 let handle = crate::runtime::spawn(&name, {
177 let name = name.clone();
178 async move {
179 if quiet {
181 trace!(target: LOG_TASK, "Starting task {name}");
182 } else {
183 debug!(target: LOG_TASK, "Starting task {name}");
184 }
185 let r = f(handle).await;
187 if quiet {
188 trace!(target: LOG_TASK, "Finished task {name}");
189 } else {
190 debug!(target: LOG_TASK, "Finished task {name}");
191 }
192 let _ = tx.send(r);
193 }
194 });
195 self.inner.add_join_handle(name, handle);
196 guard.completed = true;
197
198 rx
199 }
200
201 pub fn spawn_cancellable<R>(
204 &self,
205 name: impl Into<String>,
206 future: impl Future<Output = R> + MaybeSend + 'static,
207 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
208 where
209 R: MaybeSend + 'static,
210 {
211 self.spawn(name, |handle| async move {
212 let value = handle.cancel_on_shutdown(future).await;
213 if value.is_err() {
214 debug!(target: LOG_TASK, "task cancelled on shutdown");
216 }
217 value
218 })
219 }
220
221 pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
222 let deadline = timeout.map(|timeout| now() + timeout);
223 let mut errors = vec![];
224
225 self.join_all_inner(deadline, &mut errors).await;
226
227 if errors.is_empty() {
228 Ok(())
229 } else {
230 let num_errors = errors.len();
231 bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
232 }
233 }
234
235 #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
236 #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
237 pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
238 self.inner.join_all(deadline, errors).await;
239 }
240}
241
242struct TaskPanicGuard {
243 name: String,
244 inner: Arc<TaskGroupInner>,
245 completed: bool,
247}
248
249impl Drop for TaskPanicGuard {
250 fn drop(&mut self) {
251 if !self.completed {
252 info!(
253 target: LOG_TASK,
254 "Task {} shut down uncleanly. Shutting down task group.", self.name
255 );
256 self.inner.shutdown();
257 }
258 }
259}
260
261#[derive(Clone, Debug)]
262pub struct TaskHandle {
263 inner: Arc<TaskGroupInner>,
264}
265
266#[derive(thiserror::Error, Debug, Clone)]
267#[error("Task group is shutting down")]
268#[non_exhaustive]
269pub struct ShuttingDownError {}
270
271impl TaskHandle {
272 pub fn is_shutting_down(&self) -> bool {
276 self.inner.is_shutting_down()
277 }
278
279 pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
284 self.inner.make_shutdown_rx()
285 }
286
287 pub async fn cancel_on_shutdown<F: Future>(
289 &self,
290 fut: F,
291 ) -> Result<F::Output, ShuttingDownError> {
292 let rx = self.make_shutdown_rx();
293 match future::select(pin!(rx), pin!(fut)).await {
294 Either::Left(((), _)) => Err(ShuttingDownError {}),
295 Either::Right((value, _)) => Ok(value),
296 }
297 }
298}
299
300pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
301
302impl TaskShutdownToken {
303 fn new(mut rx: watch::Receiver<bool>) -> Self {
304 Self(Box::pin(async move {
305 let _ = rx.wait_for(|v| *v).await;
306 }))
307 }
308}
309
310impl Future for TaskShutdownToken {
311 type Output = ();
312
313 fn poll(
314 mut self: Pin<&mut Self>,
315 cx: &mut std::task::Context<'_>,
316 ) -> std::task::Poll<Self::Output> {
317 self.0.as_mut().poll(cx)
318 }
319}
320
/// Applies `#[async_trait]` to the wrapped item.
///
/// On wasm targets it applies `#[async_trait(?Send)]` instead, so the
/// generated futures are not required to be `Send`.
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        $($item)*
    };
}
345
/// Appends a `+ Send` bound to the given type or bound list (native targets).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bounds:tt)*) => (
        $($bounds)* + Send
    );
}

/// Passes the given type or bound list through unchanged (wasm targets).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bounds:tt)*) => (
        $($bounds)*
    );
}
378
/// Appends `+ Send + Sync` bounds to the given type or bound list (native
/// targets).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bounds:tt)*) => (
        $($bounds)* + Send + Sync
    );
}

/// Passes the given type or bound list through unchanged (wasm targets).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bounds:tt)*) => (
        $($bounds)*
    );
}
396
/// Marker trait that means `Send` on native targets and imposes nothing on
/// wasm.
///
/// Blanket-implemented for every type that satisfies the target's requirement.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend
where
    Self: Send,
{
}

#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSend for T where T: Send {}

/// Marker trait that means `Send` on native targets and imposes nothing on
/// wasm.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
416
/// Marker trait that means `Sync` on native targets and imposes nothing on
/// wasm.
///
/// Blanket-implemented for every type that satisfies the target's requirement.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync
where
    Self: Sync,
{
}

#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSync for T where T: Sync {}

/// Marker trait that means `Sync` on native targets and imposes nothing on
/// wasm.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
430
431pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
434 info!(
435 target: LOG_TEST,
436 "Sleeping for {}.{:03} seconds because: {}",
437 duration.as_secs(),
438 duration.subsec_millis(),
439 comment.as_ref()
440 );
441 sleep(duration).await;
442}
443
444#[derive(Error, Debug)]
446#[error("Operation cancelled")]
447pub struct Cancelled;
448
449pub type Cancellable<T> = std::result::Result<T, Cancelled>;
452
#[cfg(test)]
mod tests {
    use super::*;

    /// Delay used to order the spawned task's shutdown wait relative to the
    /// shutdown request.
    const DELAY: Duration = Duration::from_millis(10);

    /// Shutdown requested after the task is already waiting on its token.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_after() -> anyhow::Result<()> {
        let task_group = TaskGroup::new();
        task_group.spawn("shutdown waiter", |task_handle| async move {
            task_handle.make_shutdown_rx().await;
        });
        sleep(DELAY).await;
        task_group.shutdown_join_all(None).await?;
        Ok(())
    }

    /// Shutdown requested before the task starts waiting on its token.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_before() -> anyhow::Result<()> {
        let task_group = TaskGroup::new();
        task_group.spawn("shutdown waiter", |task_handle| async move {
            sleep(DELAY).await;
            task_handle.make_shutdown_rx().await;
        });
        task_group.shutdown_join_all(None).await?;
        Ok(())
    }

    /// Same as `shutdown_task_group_after`, with the task on a subgroup.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_after() -> anyhow::Result<()> {
        let task_group = TaskGroup::new();
        let subgroup = task_group.make_subgroup();
        subgroup.spawn("shutdown waiter", |task_handle| async move {
            task_handle.make_shutdown_rx().await;
        });
        sleep(DELAY).await;
        task_group.shutdown_join_all(None).await?;
        Ok(())
    }

    /// Same as `shutdown_task_group_before`, with the task on a subgroup.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_before() -> anyhow::Result<()> {
        let task_group = TaskGroup::new();
        let subgroup = task_group.make_subgroup();
        subgroup.spawn("shutdown waiter", |task_handle| async move {
            sleep(DELAY).await;
            task_handle.make_shutdown_rx().await;
        });
        task_group.shutdown_join_all(None).await?;
        Ok(())
    }
}