waitgroup/
lib.rs

//! A `WaitGroup` waits for a collection of tasks to finish.
//!
//! Each task holds a [`Worker`]; dropping it marks that task as finished. The future
//! returned by [`WaitGroup::wait`] resolves once every worker has been dropped.
//!
//! ## Examples
//! ```rust
//! use waitgroup::WaitGroup;
//! use async_std::task;
//! # task::block_on(
//! async {
//!     let wg = WaitGroup::new();
//!     for _ in 0..100 {
//!         let w = wg.worker();
//!         task::spawn(async move {
//!             // do work
//!             drop(w); // dropping `w` marks this task as finished
//!         });
//!     }
//!
//!     wg.wait().await;
//! }
//! # );
//! ```
//!
23
use atomic_waker::AtomicWaker;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{fence, Ordering};
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
29
30pub struct WaitGroup {
31    inner: Arc<Inner>,
32}
33
34#[derive(Clone)]
35pub struct Worker {
36    inner: Arc<Inner>,
37}
38
39pub struct WaitGroupFuture {
40    inner: Weak<Inner>,
41}
42
43struct Inner {
44    waker: AtomicWaker,
45}
46
47impl Drop for Inner {
48    fn drop(&mut self) {
49        self.waker.wake();
50    }
51}
52
53impl WaitGroup {
54    pub fn new() -> Self {
55        Self {
56            inner: Arc::new(Inner {
57                waker: AtomicWaker::new(),
58            }),
59        }
60    }
61
62    pub fn worker(&self) -> Worker {
63        Worker {
64            inner: self.inner.clone(),
65        }
66    }
67
68    pub fn wait(self) -> WaitGroupFuture {
69        WaitGroupFuture {
70            inner: Arc::downgrade(&self.inner),
71        }
72    }
73}
74
75impl Default for WaitGroup {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
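/*
`IntoFuture` was stabilized in Rust 1.64 (tracking issue:
https://github.com/rust-lang/rust/issues/67644). If the crate can require Rust 1.64 or
newer, this impl (plus `use std::future::IntoFuture;`) lets callers write `wg.await`
instead of `wg.wait().await`. Note that the stable trait names its associated
future type `IntoFuture`, not `Future`:

impl IntoFuture for WaitGroup {
    type Output = ();
    type IntoFuture = WaitGroupFuture;

    fn into_future(self) -> Self::IntoFuture {
        self.wait()
    }
}
*/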
92
93impl Future for WaitGroupFuture {
94    type Output = ();
95
96    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97        match self.inner.upgrade() {
98            Some(inner) => {
99                inner.waker.register(cx.waker());
100                Poll::Pending
101            }
102            None => Poll::Ready(()),
103        }
104    }
105}
106
#[cfg(test)]
mod test {
    use super::*;
    use async_std::task;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Spawns many workers and checks that `wait` resolves only after all of them ran.
    #[async_std::test]
    async fn smoke() {
        let wg = WaitGroup::new();
        let finished = Arc::new(AtomicUsize::new(0));

        for _ in 0..100 {
            let w = wg.worker();
            let finished = Arc::clone(&finished);
            task::spawn(async move {
                finished.fetch_add(1, Ordering::SeqCst);
                drop(w);
            });
        }

        wg.wait().await;
        // Each task increments before dropping its worker, so all 100 increments
        // must be visible once the group has resolved.
        assert_eq!(finished.load(Ordering::SeqCst), 100);
    }

    /// A group that never handed out a worker completes immediately.
    #[async_std::test]
    async fn no_workers() {
        WaitGroup::new().wait().await;
    }

    /// A cloned worker is a separate handle; the group waits for both to be dropped.
    #[async_std::test]
    async fn cloned_workers() {
        let wg = WaitGroup::new();
        let w = wg.worker();
        let w2 = w.clone();
        let finished = Arc::new(AtomicUsize::new(0));
        let marker = Arc::clone(&finished);

        task::spawn(async move {
            drop(w);
            // Still one handle alive here, so `wait` must not have resolved yet.
            marker.fetch_add(1, Ordering::SeqCst);
            drop(w2);
        });

        wg.wait().await;
        assert_eq!(finished.load(Ordering::SeqCst), 1);
    }
}
125}