1use atomic_waker::AtomicWaker;
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::{Arc, Weak};
28use std::task::{Context, Poll};
29
30pub struct WaitGroup {
31 inner: Arc<Inner>,
32}
33
34#[derive(Clone)]
35pub struct Worker {
36 inner: Arc<Inner>,
37}
38
39pub struct WaitGroupFuture {
40 inner: Weak<Inner>,
41}
42
43struct Inner {
44 waker: AtomicWaker,
45}
46
47impl Drop for Inner {
48 fn drop(&mut self) {
49 self.waker.wake();
50 }
51}
52
53impl WaitGroup {
54 pub fn new() -> Self {
55 Self {
56 inner: Arc::new(Inner {
57 waker: AtomicWaker::new(),
58 }),
59 }
60 }
61
62 pub fn worker(&self) -> Worker {
63 Worker {
64 inner: self.inner.clone(),
65 }
66 }
67
68 pub fn wait(self) -> WaitGroupFuture {
69 WaitGroupFuture {
70 inner: Arc::downgrade(&self.inner),
71 }
72 }
73}
74
75impl Default for WaitGroup {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl Future for WaitGroupFuture {
94 type Output = ();
95
96 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
97 match self.inner.upgrade() {
98 Some(inner) => {
99 inner.waker.register(cx.waker());
100 Poll::Pending
101 }
102 None => Poll::Ready(()),
103 }
104 }
105}
106
#[cfg(test)]
mod test {
    use super::*;
    use async_std::task;

    #[async_std::test]
    async fn smoke() {
        let wg = WaitGroup::new();

        for _ in 0..100 {
            let w = wg.worker();
            task::spawn(async move {
                drop(w);
            });
        }

        wg.wait().await;
    }

    /// With no workers, waiting must complete on the first poll.
    #[async_std::test]
    async fn no_workers() {
        WaitGroup::new().wait().await;
    }

    /// Clones of a worker are tracked just like workers created via `worker()`.
    #[async_std::test]
    async fn cloned_workers() {
        let wg = WaitGroup::new();
        let worker = wg.worker();

        for _ in 0..10 {
            let w = worker.clone();
            task::spawn(async move {
                drop(w);
            });
        }
        drop(worker);

        wg.wait().await;
    }

    /// Everything a worker does before dropping its handle must be visible once
    /// `wait` completes.
    #[async_std::test]
    async fn observes_worker_effects() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let wg = WaitGroup::new();
        let done = Arc::new(AtomicUsize::new(0));

        for _ in 0..100 {
            let w = wg.worker();
            let done = Arc::clone(&done);
            task::spawn(async move {
                // Relaxed on purpose: visibility must come from the wait group itself.
                done.fetch_add(1, Ordering::Relaxed);
                drop(w);
            });
        }

        wg.wait().await;
        assert_eq!(done.load(Ordering::Relaxed), 100);
    }
}
125}