1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{pin, Pin};
11use std::sync::Arc;
12use std::time::{Duration, SystemTime};
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use thiserror::Error;
20use tokio::sync::{oneshot, watch};
21use tracing::{debug, error, info};
22
23use crate::runtime;
24pub use crate::runtime::*;
27#[derive(Clone, Default, Debug)]
38pub struct TaskGroup {
39 inner: Arc<TaskGroupInner>,
40}
41
42impl TaskGroup {
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn make_handle(&self) -> TaskHandle {
48 TaskHandle {
49 inner: self.inner.clone(),
50 }
51 }
52
53 pub fn make_subgroup(&self) -> Self {
66 let new_tg = Self::new();
67 self.inner.add_subgroup(new_tg.clone());
68 new_tg
69 }
70
71 pub fn shutdown(&self) {
74 self.inner.shutdown();
75 }
76
77 pub async fn shutdown_join_all(
79 self,
80 join_timeout: impl Into<Option<Duration>>,
81 ) -> Result<(), anyhow::Error> {
82 self.shutdown();
83 self.join_all(join_timeout.into()).await
84 }
85
86 #[cfg(not(target_family = "wasm"))]
89 pub fn install_kill_handler(&self) {
90 async fn wait_for_shutdown_signal() {
92 use tokio::signal;
93
94 let ctrl_c = async {
95 signal::ctrl_c()
96 .await
97 .expect("failed to install Ctrl+C handler");
98 };
99
100 #[cfg(unix)]
101 let terminate = async {
102 signal::unix::signal(signal::unix::SignalKind::terminate())
103 .expect("failed to install signal handler")
104 .recv()
105 .await;
106 };
107
108 #[cfg(not(unix))]
109 let terminate = std::future::pending::<()>();
110
111 tokio::select! {
112 () = ctrl_c => {},
113 () = terminate => {},
114 }
115 }
116
117 runtime::spawn("kill handlers", {
118 let task_group = self.clone();
119 async move {
120 wait_for_shutdown_signal().await;
121 info!(
122 target: LOG_TASK,
123 "signal received, starting graceful shutdown"
124 );
125 task_group.shutdown();
126 }
127 });
128 }
129
130 pub fn spawn<Fut, R>(
131 &self,
132 name: impl Into<String>,
133 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
134 ) -> oneshot::Receiver<R>
135 where
136 Fut: Future<Output = R> + MaybeSend + 'static,
137 R: MaybeSend + 'static,
138 {
139 let name = name.into();
140 let mut guard = TaskPanicGuard {
141 name: name.clone(),
142 inner: self.inner.clone(),
143 completed: false,
144 };
145 let handle = self.make_handle();
146
147 let (tx, rx) = oneshot::channel();
148 let handle = crate::runtime::spawn(&name, {
149 let name = name.clone();
150 async move {
151 debug!("Starting task {name}");
153 let r = f(handle).await;
154 debug!("Finished task {name}");
155 let _ = tx.send(r);
156 }
157 });
158 self.inner.add_join_handle(name, handle);
159 guard.completed = true;
160
161 rx
162 }
163
164 pub fn spawn_cancellable<R>(
167 &self,
168 name: impl Into<String>,
169 future: impl Future<Output = R> + MaybeSend + 'static,
170 ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
171 where
172 R: MaybeSend + 'static,
173 {
174 self.spawn(name, |handle| async move {
175 let value = handle.cancel_on_shutdown(future).await;
176 if value.is_err() {
177 debug!("task cancelled on shutdown");
179 }
180 value
181 })
182 }
183
184 pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
185 let deadline = timeout.map(|timeout| now() + timeout);
186 let mut errors = vec![];
187
188 self.join_all_inner(deadline, &mut errors).await;
189
190 if errors.is_empty() {
191 Ok(())
192 } else {
193 let num_errors = errors.len();
194 bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
195 }
196 }
197
198 #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
199 #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
200 pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
201 self.inner.join_all(deadline, errors).await;
202 }
203}
204
205struct TaskPanicGuard {
206 name: String,
207 inner: Arc<TaskGroupInner>,
208 completed: bool,
210}
211
212impl Drop for TaskPanicGuard {
213 fn drop(&mut self) {
214 if !self.completed {
215 info!(
216 target: LOG_TASK,
217 "Task {} shut down uncleanly. Shutting down task group.", self.name
218 );
219 self.inner.shutdown();
220 }
221 }
222}
223
224#[derive(Clone, Debug)]
225pub struct TaskHandle {
226 inner: Arc<TaskGroupInner>,
227}
228
229#[derive(thiserror::Error, Debug, Clone)]
230#[error("Task group is shutting down")]
231#[non_exhaustive]
232pub struct ShuttingDownError {}
233
234impl TaskHandle {
235 pub fn is_shutting_down(&self) -> bool {
239 self.inner.is_shutting_down()
240 }
241
242 pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
247 self.inner.make_shutdown_rx()
248 }
249
250 pub async fn cancel_on_shutdown<F: Future>(
252 &self,
253 fut: F,
254 ) -> Result<F::Output, ShuttingDownError> {
255 let rx = self.make_shutdown_rx();
256 match future::select(pin!(rx), pin!(fut)).await {
257 Either::Left(((), _)) => Err(ShuttingDownError {}),
258 Either::Right((value, _)) => Ok(value),
259 }
260 }
261}
262
263pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
264
265impl TaskShutdownToken {
266 fn new(mut rx: watch::Receiver<bool>) -> Self {
267 Self(Box::pin(async move {
268 let _ = rx.wait_for(|v| *v).await;
269 }))
270 }
271}
272
273impl Future for TaskShutdownToken {
274 type Output = ();
275
276 fn poll(
277 mut self: Pin<&mut Self>,
278 cx: &mut std::task::Context<'_>,
279 ) -> std::task::Poll<Self::Output> {
280 self.0.as_mut().poll(cx)
281 }
282}
283
/// Applies `#[async_trait::async_trait]` to the wrapped item on non-wasm
/// targets, and `#[async_trait::async_trait(?Send)]` on wasm.
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($item)*
    };
}
308
/// Appends `+ Send` to the given type tokens (non-wasm targets).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($ty_tokens:tt)*) => {
        $($ty_tokens)* + Send
    };
}

/// Passes the given type tokens through unchanged (wasm targets: no `Send`
/// bound is added).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($ty_tokens:tt)*) => {
        $($ty_tokens)*
    };
}
341
/// Appends `+ Send + Sync` to the given type tokens (non-wasm targets).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty_tokens:tt)*) => {
        $($ty_tokens)* + Send + Sync
    };
}

/// Passes the given type tokens through unchanged (wasm targets: no `Send` or
/// `Sync` bounds are added).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($ty_tokens:tt)*) => {
        $($ty_tokens)*
    };
}
359
/// Marker trait that is implemented by every type on wasm, and that has
/// `Send` as a supertrait on other targets.
///
/// Lets a single bound (`T: MaybeSend`) be written for code shared between
/// both targets.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

/// Marker trait that is implemented by every type on wasm, and that has
/// `Send` as a supertrait on other targets.
///
/// Lets a single bound (`T: MaybeSend`) be written for code shared between
/// both targets.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

// `?Sized` so that unsized `Send` types (`str`, slices, trait objects) also
// satisfy `MaybeSend` bounds, mirroring plain `Send`.
#[cfg(not(target_family = "wasm"))]
impl<T: Send + ?Sized> MaybeSend for T {}

#[cfg(target_family = "wasm")]
impl<T: ?Sized> MaybeSend for T {}
379
/// Marker trait that is implemented by every type on wasm, and that has
/// `Sync` as a supertrait on other targets.
///
/// Lets a single bound (`T: MaybeSync`) be written for code shared between
/// both targets.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

/// Marker trait that is implemented by every type on wasm, and that has
/// `Sync` as a supertrait on other targets.
///
/// Lets a single bound (`T: MaybeSync`) be written for code shared between
/// both targets.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

// `?Sized` so that unsized `Sync` types (`str`, slices, trait objects) also
// satisfy `MaybeSync` bounds, mirroring plain `Sync`.
#[cfg(not(target_family = "wasm"))]
impl<T: Sync + ?Sized> MaybeSync for T {}

#[cfg(target_family = "wasm")]
impl<T: ?Sized> MaybeSync for T {}
393
394pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
397 info!(
398 target: LOG_TEST,
399 "Sleeping for {}.{:03} seconds because: {}",
400 duration.as_secs(),
401 duration.subsec_millis(),
402 comment.as_ref()
403 );
404 sleep(duration).await;
405}
406
407#[derive(Error, Debug)]
409#[error("Operation cancelled")]
410pub struct Cancelled;
411
412pub type Cancellable<T> = std::result::Result<T, Cancelled>;
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Pause that gives a spawned task time to start before the group is shut down.
    const SETTLE_DELAY: Duration = Duration::from_millis(10);

    /// Spawns a "shutdown waiter" task on `tg`.
    ///
    /// The task sleeps for `delay_before_wait` (if any) and then waits for the
    /// group's shutdown signal.
    fn spawn_shutdown_waiter(tg: &TaskGroup, delay_before_wait: Option<Duration>) {
        tg.spawn("shutdown waiter", move |handle| async move {
            if let Some(delay) = delay_before_wait {
                sleep(delay).await;
            }
            handle.make_shutdown_rx().await;
        });
    }

    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        spawn_shutdown_waiter(&tg, None);
        sleep(SETTLE_DELAY).await;
        tg.shutdown_join_all(None).await?;
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        spawn_shutdown_waiter(&tg, Some(SETTLE_DELAY));
        tg.shutdown_join_all(None).await?;
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        spawn_shutdown_waiter(&tg.make_subgroup(), None);
        sleep(SETTLE_DELAY).await;
        tg.shutdown_join_all(None).await?;
        Ok(())
    }

    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        spawn_shutdown_waiter(&tg.make_subgroup(), Some(SETTLE_DELAY));
        tg.shutdown_join_all(None).await?;
        Ok(())
    }
}