async_scoped/
scoped.rs

1use std::marker::PhantomData;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::future::{AbortHandle, Abortable};
6use futures::stream::FuturesOrdered;
7use futures::{Future, Stream};
8
9use pin_project::*;
10
11use crate::spawner::*;
12
13/// A scope to allow controlled spawning of non 'static
14/// futures. Futures can be spawned using `spawn` or
15/// `spawn_cancellable` methods.
16///
17/// # Safety
18///
19/// This type uses `Drop` implementation to guarantee
20/// safety. It is not safe to forget this object unless it
21/// is driven to completion.
22#[pin_project(PinnedDrop)]
23pub struct Scope<'a, T, Sp: Spawner<T> + Blocker> {
24    spawner: Option<Sp>,
25    len: usize,
26    #[pin]
27    futs: FuturesOrdered<Sp::SpawnHandle>,
28    abort_handles: Vec<AbortHandle>,
29
30    // Future proof against variance changes
31    _marker: PhantomData<fn(&'a ()) -> &'a ()>,
32    _spawn_marker: PhantomData<Sp>,
33}
34
35impl<'a, T: Send + 'static, Sp: Spawner<T> + Blocker> Scope<'a, T, Sp> {
36    /// Create a Scope object.
37    ///
38    /// This function is unsafe as `futs` may hold futures
39    /// which have to be manually driven to completion.
40    pub unsafe fn create(spawner: Sp) -> Self {
41        Scope {
42            spawner: Some(spawner),
43            len: 0,
44            futs: FuturesOrdered::new(),
45            abort_handles: vec![],
46            _marker: PhantomData,
47            _spawn_marker: PhantomData,
48        }
49    }
50
51    fn spawner(&self) -> &Sp {
52        self.spawner
53            .as_ref()
54            .expect("invariant:spawner is always available until scope is dropped")
55    }
56
57    /// Spawn a future with the executor's `task::spawn` functionality. The
58    /// future is expected to be driven to completion before 'a expires.
59    pub fn spawn<F: Future<Output = T> + Send + 'a>(&mut self, f: F) {
60        let handle = self.spawner().spawn(unsafe {
61            std::mem::transmute::<_, Pin<Box<dyn Future<Output = T> + Send>>>(
62                Box::pin(f) as Pin<Box<dyn Future<Output = T>>>
63            )
64        });
65        self.futs.push_back(handle);
66        self.len += 1;
67    }
68
69    /// Spawn a cancellable future with the executor's `task::spawn`
70    /// functionality.
71    ///
72    /// The future is cancelled if the `Scope` is dropped
73    /// pre-maturely. It can also be cancelled by explicitly
74    /// calling (and awaiting) the `cancel` method.
75    #[inline]
76    pub fn spawn_cancellable<F: Future<Output = T> + Send + 'a, Fu: FnOnce() -> T + Send + 'a>(
77        &mut self,
78        f: F,
79        default: Fu,
80    ) {
81        let (h, reg) = AbortHandle::new_pair();
82        self.abort_handles.push(h);
83        let fut = Abortable::new(f, reg);
84        self.spawn(async { fut.await.unwrap_or_else(|_| default()) })
85    }
86
87    /// Spawn a function as a blocking future with executor's `spawn_blocking`
88    /// functionality.
89    ///
90    /// The future is cancelled if the `Scope` is dropped
91    /// pre-maturely. It can also be cancelled by explicitly
92    /// calling (and awaiting) the `cancel` method.
93    pub fn spawn_blocking<F: FnOnce() -> T + Send + 'a>(&mut self, f: F)
94    where
95        Sp: FuncSpawner<T, SpawnHandle = <Sp as Spawner<T>>::SpawnHandle>,
96    {
97        let handle = self.spawner().spawn_func(unsafe {
98            std::mem::transmute::<_, Box<dyn FnOnce() -> T + Send>>(
99                Box::new(f) as Box<dyn FnOnce() -> T + Send>
100            )
101        });
102        self.futs.push_back(handle);
103        self.len += 1;
104    }
105}
106
107impl<'a, T, Sp: Spawner<T> + Blocker> Scope<'a, T, Sp> {
108    /// Cancel all futures spawned with cancellation.
109    #[inline]
110    pub fn cancel(&mut self) {
111        for h in self.abort_handles.drain(..) {
112            h.abort();
113        }
114    }
115
116    /// Total number of futures spawned in this scope.
117    #[inline]
118    pub fn len(&self) -> usize {
119        self.len
120    }
121
122    /// Number of futures remaining in this scope.
123    #[inline]
124    pub fn remaining(&self) -> usize {
125        self.futs.len()
126    }
127
128    /// A slighly optimized `collect` on the stream. Also
129    /// useful when we can not move out of self.
130    pub async fn collect(&mut self) -> Vec<Sp::FutureOutput> {
131        let mut proc_outputs = Vec::with_capacity(self.remaining());
132
133        use futures::StreamExt;
134        while let Some(item) = self.next().await {
135            proc_outputs.push(item);
136        }
137
138        proc_outputs
139    }
140}
141
142impl<'a, T, Sp: Spawner<T> + Blocker> Stream for Scope<'a, T, Sp> {
143    type Item = Sp::FutureOutput;
144
145    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
146        self.project().futs.poll_next(cx)
147    }
148
149    fn size_hint(&self) -> (usize, Option<usize>) {
150        (self.remaining(), Some(self.remaining()))
151    }
152}
153
154#[pinned_drop]
155impl<'a, T, Sp: Spawner<T> + Blocker> PinnedDrop for Scope<'a, T, Sp> {
156    fn drop(mut self: Pin<&mut Self>) {
157        if self.remaining() > 0 {
158            let spawner = self
159                .spawner
160                .take()
161                .expect("invariant:spawner must be taken only on drop");
162            spawner.block_on(async {
163                self.cancel();
164                self.collect().await;
165            });
166        }
167    }
168}